Skip to content

API Reference

This page contains the API documentation for all Python modules in the codebase (excluding `__init__.py` files).

aiperf.cli

Main CLI entry point for the AIPerf system.

analyze(user_config, service_config=None)

Sweep through one or more parameters.

Source code in aiperf/cli.py
44
45
46
47
48
49
50
51
52
@app.command(name="analyze")
def analyze(
    user_config: UserConfig,
    service_config: ServiceConfig | None = None,
) -> None:
    """Sweep through one or more parameters."""
    # Stub: parameter sweeping is not implemented yet, so warn the user and exit.
    warn_command_not_implemented("analyze")

create_template(template_filename=CLIDefaults.TEMPLATE_FILENAME)

Create a template configuration file.

Source code in aiperf/cli.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
@app.command(name="create-template")
def create_template(
    template_filename: Annotated[
        str,
        Field(
            description=f"Path to the template file. Defaults to {CLIDefaults.TEMPLATE_FILENAME}."
        ),
        Parameter(
            name=("--template-filename", "-t"),
        ),
    ] = CLIDefaults.TEMPLATE_FILENAME,
) -> None:
    """Create a template configuration file."""
    # Stub: template generation is not implemented yet, so warn the user and exit.
    warn_command_not_implemented("create-template")

profile(user_config, service_config=None)

Run the Profile subcommand.

Parameters:

Name Type Description Default
user_config UserConfig

User configuration for the benchmark

required
service_config ServiceConfig | None

Service configuration options

None
Source code in aiperf/cli.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
@app.command(name="profile")
def profile(
    user_config: UserConfig,
    service_config: ServiceConfig | None = None,
) -> None:
    """Run the Profile subcommand.

    Args:
        user_config: User configuration for the benchmark
        service_config: Service configuration options
    """
    with exit_on_error(title="Error Running AIPerf System"):
        # Imports are deferred so CLI startup stays fast.
        from aiperf.cli_runner import run_system_controller
        from aiperf.common.config import load_service_config

        # Use the caller-supplied config, falling back to the loaded default.
        effective_service_config = service_config or load_service_config()

        run_system_controller(user_config, effective_service_config)

validate_config(user_config=None, service_config=None)

Validate the configuration file.

Source code in aiperf/cli.py
73
74
75
76
77
78
79
80
81
@app.command(name="validate-config")
def validate_config(
    user_config: UserConfig | None = None,
    service_config: ServiceConfig | None = None,
) -> None:
    """Validate the configuration file."""
    # Stub: validation is not implemented yet, so warn the user and exit.
    warn_command_not_implemented("validate-config")

aiperf.cli_runner

run_system_controller(user_config, service_config)

Run the system controller with the given configuration.

Source code in aiperf/cli_runner.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def run_system_controller(
    user_config: UserConfig,
    service_config: ServiceConfig,
) -> None:
    """Run the system controller with the given configuration."""
    # Heavy imports are deferred until the command actually runs.
    from aiperf.common.aiperf_logger import AIPerfLogger
    from aiperf.common.bootstrap import bootstrap_and_run_service
    from aiperf.controller import SystemController
    from aiperf.module_loader import ensure_modules_loaded

    log = AIPerfLogger(__name__)

    # No log queue is wired up here; rich console logging is configured
    # only when the interactive UI is disabled.
    log_queue = None
    if service_config.disable_ui:
        from aiperf.common.logging import setup_rich_logging

        setup_rich_logging(user_config, service_config)

    log.info("Starting AIPerf System")

    # Module-loading failures are fatal: show an error panel and exit.
    try:
        ensure_modules_loaded()
    except Exception as e:
        raise_startup_error_and_exit(
            f"Error loading modules: {e}",
            title="Error Loading Modules",
        )

    # Run the controller; log any failure, re-raise it, and always record exit.
    try:
        bootstrap_and_run_service(
            SystemController,
            service_id="system_controller",
            service_config=service_config,
            user_config=user_config,
            log_queue=log_queue,
        )
    except Exception:
        log.exception("Error running AIPerf System")
        raise
    finally:
        log.debug("AIPerf System exited")

aiperf.cli_utils

exit_on_error

Bases: AbstractContextManager

Context manager that exits the program if an error occurs.

Parameters:

Name Type Description Default
*exceptions type[BaseException]

The exceptions to exit on. If no exceptions are provided, all exceptions will be caught.

()
message RenderableType

The message to display. Can be a string or a rich renderable. Will be formatted with the exception as {e}.

'{e}'
text_color StyleType

The text color to use.

'bold red'
title str

The title of the error.

'Error'
exit_code int

The exit code to use.

1
Source code in aiperf/cli_utils.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class exit_on_error(AbstractContextManager):
    """Context manager that exits the program if an error occurs.

    Args:
        *exceptions: The exceptions to exit on. If no exceptions are provided, all exceptions will be caught.
        message: The message to display. Can be a string or a rich renderable. Will be formatted with the exception as `{e}`.
        text_color: The text color to use.
        title: The title of the error.
        exit_code: The exit code to use.
    """

    def __init__(
        self,
        *exceptions: type[BaseException],
        message: "RenderableType" = "{e}",
        text_color: "StyleType" = "bold red",
        title: str = "Error",
        exit_code: int = 1,
    ):
        self.message: RenderableType = message
        self.text_color: StyleType = text_color
        self.title: str = title
        self.exit_code: int = exit_code
        self.exceptions: tuple[type[BaseException], ...] = exceptions

    def __enter__(self):
        # Nothing to set up; return self so `with exit_on_error(...) as ctx:` works.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            # No exception was raised inside the `with` block; nothing to do.
            return

        # Exit the program when either:
        #   - no explicit exception types were given AND the exception is not a
        #     SystemExit/KeyboardInterrupt (those are allowed to propagate), or
        #   - the exception matches one of the explicitly requested types.
        # NOTE: `(SystemExit | KeyboardInterrupt)` is a PEP 604 union object (not
        # a tuple), which isinstance() accepts on Python 3.10+. With an empty
        # `self.exceptions` tuple, issubclass(...) is always False, so only the
        # first branch can match in that case.
        if (
            not self.exceptions
            and not isinstance(exc_value, (SystemExit | KeyboardInterrupt))
        ) or issubclass(exc_type, self.exceptions):
            # String messages get `{e}` formatted with the exception;
            # rich renderables are passed through unchanged.
            message = (
                self.message.format(e=exc_value)
                if isinstance(self.message, str)
                else self.message
            )
            # Prints a rich error panel and calls sys.exit(self.exit_code).
            raise_startup_error_and_exit(
                message,
                text_color=self.text_color,
                title=self.title,
                exit_code=self.exit_code,
            )

raise_startup_error_and_exit(message, text_color='bold red', title='Error', exit_code=1, border_style='red', title_align='left')

Raise a startup error and exit the program.

Parameters:

Name Type Description Default
message RenderableType

The message to display. Can be a string or a rich renderable.

required
text_color StyleType

The text color to use.

'bold red'
title str

The title of the error.

'Error'
exit_code int

The exit code to use.

1
border_style StyleType

The border style to use.

'red'
title_align AlignMethod

The alignment of the title.

'left'
Source code in aiperf/cli_utils.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def raise_startup_error_and_exit(
    message: "RenderableType",
    text_color: "StyleType" = "bold red",
    title: str = "Error",
    exit_code: int = 1,
    border_style: "StyleType" = "red",
    title_align: "AlignMethod" = "left",
) -> None:
    """Raise a startup error and exit the program.

    Args:
        message: The message to display. Can be a string or a rich renderable.
        text_color: The text color to use.
        title: The title of the error.
        exit_code: The exit code to use.
        border_style: The border style to use.
        title_align: The alignment of the title.
    """
    import sys

    from rich.console import Console
    from rich.panel import Panel

    # Plain strings are wrapped in rich color markup; renderables (and strings
    # with no text_color) pass through untouched.
    if isinstance(message, str) and text_color:
        message = f"[{text_color}]{message}[/{text_color}]"

    error_panel = Panel(
        renderable=message,
        title=title,
        title_align=title_align,
        border_style=border_style,
    )
    Console().print(error_panel)

    sys.exit(exit_code)

warn_cancelled_early()

Warn the user that the profile run was cancelled early.

Source code in aiperf/cli_utils.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def warn_cancelled_early() -> None:
    """Warn the user that the profile run was cancelled early."""
    from rich.console import Console
    from rich.panel import Panel
    from rich.text import Text

    warning_text = Text(
        "The profile run was cancelled early. Results shown may be incomplete or inaccurate.",
        style="yellow",
    )
    warning_panel = Panel(
        warning_text,
        border_style="bold yellow",
        title="Warning: Profile Run Cancelled Early",
        title_align="left",
    )

    console = Console()
    console.print("\n")
    console.print(warning_panel)
    # Flush so the warning is visible even if the process exits right after.
    console.file.flush()

warn_command_not_implemented(command)

Warn the user that the subcommand is not implemented.

Source code in aiperf/cli_utils.py
14
15
16
17
18
19
def warn_command_not_implemented(command: str) -> None:
    """Warn the user that the subcommand is not implemented."""
    # Delegates to the shared error-panel helper, which exits the process.
    not_implemented_message = (
        f"Command [bold red]{command}[/bold red] is not yet implemented"
    )
    raise_startup_error_and_exit(
        not_implemented_message,
        title="Not Implemented",
    )

aiperf.clients.http.aiohttp_client

AioHttpClientMixin

Bases: AIPerfLoggerMixin

A high-performance HTTP client for communicating with HTTP based REST APIs using aiohttp.

This class is optimized for maximum performance and accurate timing measurements, making it ideal for benchmarking scenarios.

Source code in aiperf/clients/http/aiohttp_client.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class AioHttpClientMixin(AIPerfLoggerMixin):
    """A high-performance HTTP client for communicating with HTTP based REST APIs using aiohttp.

    This class is optimized for maximum performance and accurate timing measurements,
    making it ideal for benchmarking scenarios.
    """

    def __init__(self, model_endpoint: ModelEndpointInfo, **kwargs) -> None:
        """Create the client for the given model endpoint.

        Args:
            model_endpoint: Endpoint information (URL, timeout, etc.) for requests.
            **kwargs: Forwarded to the next class in the MRO.
        """
        self.model_endpoint = model_endpoint
        super().__init__(model_endpoint=model_endpoint, **kwargs)
        self.tcp_connector = create_tcp_connector()

        # For now, just set all timeouts to the same value.
        # TODO: Add support for different timeouts for different parts of the request.
        self.timeout = aiohttp.ClientTimeout(
            total=self.model_endpoint.endpoint.timeout,
            connect=self.model_endpoint.endpoint.timeout,
            sock_connect=self.model_endpoint.endpoint.timeout,
            sock_read=self.model_endpoint.endpoint.timeout,
            ceil_threshold=self.model_endpoint.endpoint.timeout,
        )

    async def close(self) -> None:
        """Close the client by shutting down its TCP connector (idempotent)."""
        if self.tcp_connector:
            await self.tcp_connector.close()
            self.tcp_connector = None

    async def post_request(
        self,
        url: str,
        payload: str,
        headers: dict[str, str],
        **kwargs: Any,
    ) -> RequestRecord:
        """Send a streaming or non-streaming POST request to the specified URL with the given payload and headers.

        If the response is an SSE stream, the response will be parsed into a list of SSE messages.
        Otherwise, the response will be parsed into a TextResponse object.

        Args:
            url: Fully-qualified URL to POST to.
            payload: Pre-serialized request body.
            headers: HTTP headers to send (also excluded from aiohttp auto-headers).
            **kwargs: Extra arguments forwarded to ``session.post``.

        Returns:
            A RequestRecord with timing, status, parsed responses, and any error details.
        """

        self.debug(lambda: f"Sending POST request to {url}")

        record: RequestRecord = RequestRecord(
            start_perf_ns=time.perf_counter_ns(),
        )

        try:
            # Make raw HTTP request with precise timing using aiohttp.
            # skip_auto_headers keeps aiohttp from injecting extra headers;
            # connector_owner=False keeps the shared connector alive across sessions.
            async with aiohttp.ClientSession(
                connector=self.tcp_connector,
                timeout=self.timeout,
                headers=headers,
                skip_auto_headers=[
                    *list(headers.keys()),
                    "User-Agent",
                    "Accept-Encoding",
                ],
                connector_owner=False,
            ) as session:
                # Re-stamp the start time so session setup cost is excluded.
                record.start_perf_ns = time.perf_counter_ns()
                async with session.post(
                    url, data=payload, headers=headers, **kwargs
                ) as response:
                    record.status = response.status
                    # Check for HTTP errors
                    if response.status != 200:
                        error_text = await response.text()
                        # Bug fix: stamp the end time on the HTTP-error path too,
                        # so error records carry complete timing information
                        # (the exception path below already does this).
                        record.end_perf_ns = time.perf_counter_ns()
                        record.error = ErrorDetails(
                            code=response.status,
                            type=response.reason,
                            message=error_text,
                        )
                        return record

                    record.recv_start_perf_ns = time.perf_counter_ns()

                    if response.content_type == "text/event-stream":
                        # Parse SSE stream with optimal performance
                        messages = await AioHttpSSEStreamReader(
                            response
                        ).read_complete_stream()
                        record.responses.extend(messages)
                    else:
                        raw_response = await response.text()
                        record.end_perf_ns = time.perf_counter_ns()
                        record.responses.append(
                            TextResponse(
                                perf_ns=record.end_perf_ns,
                                content_type=response.content_type,
                                text=raw_response,
                            )
                        )
                    record.end_perf_ns = time.perf_counter_ns()

        except Exception as e:
            record.end_perf_ns = time.perf_counter_ns()
            self.error(f"Error in aiohttp request: {e}")
            record.error = ErrorDetails(type=e.__class__.__name__, message=str(e))

        return record

close() async

Close the client.

Source code in aiperf/clients/http/aiohttp_client.py
49
50
51
52
53
async def close(self) -> None:
    """Close the client by shutting down its TCP connector (idempotent)."""
    connector = self.tcp_connector
    if connector:
        await connector.close()
        # Drop the reference so a second close() is a no-op.
        self.tcp_connector = None

post_request(url, payload, headers, **kwargs) async

Send a streaming or non-streaming POST request to the specified URL with the given payload and headers.

If the response is an SSE stream, the response will be parsed into a list of SSE messages. Otherwise, the response will be parsed into a TextResponse object.

Source code in aiperf/clients/http/aiohttp_client.py
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
async def post_request(
    self,
    url: str,
    payload: str,
    headers: dict[str, str],
    **kwargs: Any,
) -> RequestRecord:
    """Send a streaming or non-streaming POST request to the specified URL with the given payload and headers.

    If the response is an SSE stream, the response will be parsed into a list of SSE messages.
    Otherwise, the response will be parsed into a TextResponse object.

    Args:
        url: Fully-qualified URL to POST to.
        payload: Pre-serialized request body.
        headers: HTTP headers to send (also excluded from aiohttp auto-headers).
        **kwargs: Extra arguments forwarded to ``session.post``.

    Returns:
        A RequestRecord with timing, status, parsed responses, and any error details.
    """

    self.debug(lambda: f"Sending POST request to {url}")

    record: RequestRecord = RequestRecord(
        start_perf_ns=time.perf_counter_ns(),
    )

    try:
        # Make raw HTTP request with precise timing using aiohttp.
        # skip_auto_headers keeps aiohttp from injecting extra headers;
        # connector_owner=False keeps the shared connector alive across sessions.
        async with aiohttp.ClientSession(
            connector=self.tcp_connector,
            timeout=self.timeout,
            headers=headers,
            skip_auto_headers=[
                *list(headers.keys()),
                "User-Agent",
                "Accept-Encoding",
            ],
            connector_owner=False,
        ) as session:
            # Re-stamp the start time so session setup cost is excluded.
            record.start_perf_ns = time.perf_counter_ns()
            async with session.post(
                url, data=payload, headers=headers, **kwargs
            ) as response:
                record.status = response.status
                # Check for HTTP errors
                if response.status != 200:
                    error_text = await response.text()
                    # Bug fix: stamp the end time on the HTTP-error path too,
                    # so error records carry complete timing information
                    # (the exception path below already does this).
                    record.end_perf_ns = time.perf_counter_ns()
                    record.error = ErrorDetails(
                        code=response.status,
                        type=response.reason,
                        message=error_text,
                    )
                    return record

                record.recv_start_perf_ns = time.perf_counter_ns()

                if response.content_type == "text/event-stream":
                    # Parse SSE stream with optimal performance
                    messages = await AioHttpSSEStreamReader(
                        response
                    ).read_complete_stream()
                    record.responses.extend(messages)
                else:
                    raw_response = await response.text()
                    record.end_perf_ns = time.perf_counter_ns()
                    record.responses.append(
                        TextResponse(
                            perf_ns=record.end_perf_ns,
                            content_type=response.content_type,
                            text=raw_response,
                        )
                    )
                record.end_perf_ns = time.perf_counter_ns()

    except Exception as e:
        record.end_perf_ns = time.perf_counter_ns()
        self.error(f"Error in aiohttp request: {e}")
        record.error = ErrorDetails(type=e.__class__.__name__, message=str(e))

    return record

AioHttpSSEStreamReader

A helper class for reading an SSE stream from an aiohttp.ClientResponse object.

This class is optimized for maximum performance and accurate timing measurements, making it ideal for benchmarking scenarios.

Source code in aiperf/clients/http/aiohttp_client.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
class AioHttpSSEStreamReader:
    """A helper class for reading an SSE stream from an aiohttp.ClientResponse object.

    This class is optimized for maximum performance and accurate timing measurements,
    making it ideal for benchmarking scenarios.
    """

    def __init__(self, response: aiohttp.ClientResponse):
        # The in-flight aiohttp response whose body is an SSE stream.
        self.response = response

    async def read_complete_stream(self) -> list[SSEMessage]:
        """Read the complete SSE stream in a performant manner and return a list of
        SSE messages that contain the most accurate timestamp data possible.

        Returns:
            A list of SSE messages.
        """
        messages: list[SSEMessage] = []

        async for raw_message, first_byte_ns in self.__aiter__():
            # Parse the raw SSE message into a SSEMessage object
            message = parse_sse_message(raw_message, first_byte_ns)
            messages.append(message)

        return messages

    async def __aiter__(self) -> typing.AsyncIterator[tuple[str, int]]:
        """Iterate over the SSE stream, yielding tuples of the raw SSE message and
        the perf_counter_ns timestamp of its first byte.

        The first byte of each event is read on its own so its arrival can be
        timestamped immediately, before the rest of the chunk is consumed. This
        provides the most accurate first-byte timing possible given the buffering
        behavior of the aiohttp library.

        Returns:
            An async iterator of tuples of the raw SSE message, and the perf_counter_ns of the first byte
        """

        while not self.response.content.at_eof():
            # Read a single byte first so the timestamp reflects the arrival of
            # the event, not the end of the full chunk.
            first_byte = await self.response.content.read(1)
            chunk_ns_first_byte = time.perf_counter_ns()
            if not first_byte:
                # EOF was reached between the at_eof() check and the read.
                break

            # SSE events are delimited by a blank line (b"\n\n").
            chunk = await self.response.content.readuntil(b"\n\n")

            if not chunk:
                break
            # Re-attach the first byte that was read separately for timing.
            chunk = first_byte + chunk

            try:
                # Use the fastest available decoder
                yield (
                    chunk.decode("utf-8").strip(),
                    chunk_ns_first_byte,
                )
            except UnicodeDecodeError:
                # Handle potential encoding issues gracefully
                yield (
                    chunk.decode("utf-8", errors="replace").strip(),
                    chunk_ns_first_byte,
                )

__aiter__() async

Iterate over the SSE stream in a performant manner, yielding tuples of the raw SSE message and the perf_counter_ns timestamp of its first byte. The first byte of each event is read and timestamped immediately, before the rest of the chunk is consumed, which provides the most accurate first-byte timing information possible given the buffering behavior of the aiohttp library.

Returns:

Type Description
AsyncIterator[tuple[str, int]]

An async iterator of tuples of the raw SSE message, and the perf_counter_ns of the first byte

Source code in aiperf/clients/http/aiohttp_client.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
async def __aiter__(self) -> typing.AsyncIterator[tuple[str, int]]:
    """Iterate over the SSE stream, yielding tuples of the raw SSE message and
    the perf_counter_ns timestamp of its first byte.

    The first byte of each event is read on its own so its arrival can be
    timestamped immediately, before the rest of the chunk is consumed. This
    provides the most accurate first-byte timing possible given the buffering
    behavior of the aiohttp library.

    Returns:
        An async iterator of tuples of the raw SSE message, and the perf_counter_ns of the first byte
    """

    while not self.response.content.at_eof():
        # Read a single byte first so the timestamp reflects the arrival of
        # the event, not the end of the full chunk.
        first_byte = await self.response.content.read(1)
        chunk_ns_first_byte = time.perf_counter_ns()
        if not first_byte:
            # EOF was reached between the at_eof() check and the read.
            break

        # SSE events are delimited by a blank line (b"\n\n").
        chunk = await self.response.content.readuntil(b"\n\n")

        if not chunk:
            break
        # Re-attach the first byte that was read separately for timing.
        chunk = first_byte + chunk

        try:
            # Use the fastest available decoder
            yield (
                chunk.decode("utf-8").strip(),
                chunk_ns_first_byte,
            )
        except UnicodeDecodeError:
            # Handle potential encoding issues gracefully
            yield (
                chunk.decode("utf-8", errors="replace").strip(),
                chunk_ns_first_byte,
            )

read_complete_stream() async

Read the complete SSE stream in a performant manner and return a list of SSE messages that contain the most accurate timestamp data possible.

Returns:

Type Description
list[SSEMessage]

A list of SSE messages.

Source code in aiperf/clients/http/aiohttp_client.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
async def read_complete_stream(self) -> list[SSEMessage]:
    """Read the complete SSE stream in a performant manner and return a list of
    SSE messages that contain the most accurate timestamp data possible.

    Returns:
        A list of SSE messages.
    """
    # Each iteration yields (raw message, first-byte perf_counter_ns), which
    # is parsed directly into an SSEMessage.
    return [
        parse_sse_message(raw_message, first_byte_ns)
        async for raw_message, first_byte_ns in self.__aiter__()
    ]

create_tcp_connector(**kwargs)

Create a new connector with the given configuration.

Source code in aiperf/clients/http/aiohttp_client.py
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
def create_tcp_connector(**kwargs) -> aiohttp.TCPConnector:
    """Create a new connector with the given configuration."""

    def socket_factory(addr_info):
        """Custom socket factory optimized for SSE streaming performance."""
        family, sock_type, proto, _, _ = addr_info
        new_sock = socket.socket(family=family, type=sock_type, proto=proto)
        SocketDefaults.apply_to_socket(new_sock)
        return new_sock

    # Start from the project defaults; caller-supplied kwargs override them.
    connector_kwargs: dict[str, Any] = {
        "limit": AioHttpDefaults.LIMIT,
        "limit_per_host": AioHttpDefaults.LIMIT_PER_HOST,
        "ttl_dns_cache": AioHttpDefaults.TTL_DNS_CACHE,
        "use_dns_cache": AioHttpDefaults.USE_DNS_CACHE,
        "enable_cleanup_closed": AioHttpDefaults.ENABLE_CLEANUP_CLOSED,
        "force_close": AioHttpDefaults.FORCE_CLOSE,
        "keepalive_timeout": AioHttpDefaults.KEEPALIVE_TIMEOUT,
        "happy_eyeballs_delay": AioHttpDefaults.HAPPY_EYEBALLS_DELAY,
        "family": AioHttpDefaults.SOCKET_FAMILY,
        "socket_factory": socket_factory,
    }
    connector_kwargs.update(kwargs)

    return aiohttp.TCPConnector(**connector_kwargs)

parse_sse_message(raw_message, perf_ns)

Parse a raw SSE message into an SSEMessage object.

Parsing logic based on official HTML SSE Living Standard: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream

Source code in aiperf/clients/http/aiohttp_client.py
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def parse_sse_message(raw_message: str, perf_ns: int) -> SSEMessage:
    """Parse a raw SSE message into an SSEMessage object.

    Parsing logic based on official HTML SSE Living Standard:
    https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
    """

    message = SSEMessage(perf_ns=perf_ns)
    for raw_line in raw_message.split("\n"):
        stripped = raw_line.strip()
        if not stripped:
            # Blank lines carry no field data.
            continue

        field_name, colon, value = stripped.partition(":")
        if not colon:
            # Fields without a colon have no value, so the whole line is the field name
            message.packets.append(SSEField(name=field_name.strip(), value=None))
            continue

        if field_name == "":
            # Field name is empty, so this is a comment
            field_name = SSEFieldType.COMMENT

        message.packets.append(SSEField(name=field_name.strip(), value=value.strip()))

    return message

aiperf.clients.http.defaults

AioHttpDefaults dataclass

Default values for aiohttp.ClientSession.

Source code in aiperf/clients/http/defaults.py
62
63
64
65
66
67
68
69
70
71
72
73
74
@dataclass(frozen=True)
class AioHttpDefaults:
    """Default values for aiohttp.ClientSession.

    NOTE: These are plain (unannotated) class attributes, so they behave as
    shared class-level constants rather than dataclass fields. Do not add type
    annotations here without intending to turn them into instance fields.
    """

    LIMIT = 2500  # Maximum number of concurrent connections
    LIMIT_PER_HOST = 2500  # Maximum number of concurrent connections per host
    TTL_DNS_CACHE = 300  # Time to live for DNS cache
    USE_DNS_CACHE = True  # Enable DNS cache
    ENABLE_CLEANUP_CLOSED = False  # Disable cleanup of closed connections
    FORCE_CLOSE = False  # Disable force close connections
    KEEPALIVE_TIMEOUT = 300  # Keepalive timeout
    HAPPY_EYEBALLS_DELAY = None  # Happy eyeballs delay (None = disabled)
    SOCKET_FAMILY = socket.AF_INET  # Family of the socket (IPv4)

SocketDefaults dataclass

Default values for socket options.

Source code in aiperf/clients/http/defaults.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@dataclass(frozen=True)
class SocketDefaults:
    """
    Default values for socket options.

    NOTE: These are plain (unannotated) class attributes, so they behave as
    shared class-level constants rather than dataclass fields.

    NOTE(review): several constants (SO_LINGER, SO_REUSEADDR, SO_REUSEPORT,
    SO_RCVTIMEO, SO_SNDTIMEO) are declared but never applied by
    apply_to_socket() below — presumably reserved for future use; confirm
    before relying on them.
    """

    TCP_NODELAY = 1  # Disable Nagle's algorithm
    TCP_QUICKACK = 1  # Quick ACK mode

    SO_KEEPALIVE = 1  # Enable keepalive
    TCP_KEEPIDLE = 60  # Start keepalive after 1 min idle
    TCP_KEEPINTVL = 30  # Keepalive interval: 30 seconds
    TCP_KEEPCNT = 1  # 1 failed keepalive probes = dead

    SO_LINGER = 0  # Disable linger
    SO_REUSEADDR = 1  # Enable reuse address
    SO_REUSEPORT = 1  # Enable reuse port

    SO_RCVBUF = 1024 * 1024 * 10  # 10MB receive buffer
    SO_SNDBUF = 1024 * 1024 * 10  # 10MB send buffer

    SO_RCVTIMEO = 30  # 30 second receive timeout
    SO_SNDTIMEO = 30  # 30 second send timeout
    TCP_USER_TIMEOUT = 30000  # 30 sec user timeout (value is in milliseconds)

    @classmethod
    def apply_to_socket(cls, sock: socket.socket) -> None:
        """Apply the default socket options to the given socket."""

        # Low-latency optimizations for streaming
        sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, cls.TCP_NODELAY)

        # Connection keepalive settings for long-lived SSE connections
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, cls.SO_KEEPALIVE)

        # Fine-tune keepalive timing (Linux-specific)
        if hasattr(socket, "TCP_KEEPIDLE"):
            sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, cls.TCP_KEEPIDLE)
            sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, cls.TCP_KEEPINTVL)
            sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, cls.TCP_KEEPCNT)

        # Buffer size optimizations for streaming
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, cls.SO_RCVBUF)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, cls.SO_SNDBUF)

        # Linux-specific TCP optimizations
        if hasattr(socket, "TCP_QUICKACK"):
            sock.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, cls.TCP_QUICKACK)

        if hasattr(socket, "TCP_USER_TIMEOUT"):
            sock.setsockopt(
                socket.SOL_TCP, socket.TCP_USER_TIMEOUT, cls.TCP_USER_TIMEOUT
            )

apply_to_socket(sock) classmethod

Apply the default socket options to the given socket.

Source code in aiperf/clients/http/defaults.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@classmethod
def apply_to_socket(cls, sock: socket.socket) -> None:
    """Apply the default socket options to the given socket."""
    setopt = sock.setsockopt

    # Disable Nagle's algorithm for low-latency streaming writes.
    setopt(socket.SOL_TCP, socket.TCP_NODELAY, cls.TCP_NODELAY)

    # Enable keepalive for long-lived SSE connections.
    setopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, cls.SO_KEEPALIVE)

    # Keepalive timing knobs only exist on some platforms (Linux).
    if hasattr(socket, "TCP_KEEPIDLE"):
        setopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, cls.TCP_KEEPIDLE)
        setopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, cls.TCP_KEEPINTVL)
        setopt(socket.SOL_TCP, socket.TCP_KEEPCNT, cls.TCP_KEEPCNT)

    # Large send/receive buffers for high-throughput streaming.
    setopt(socket.SOL_SOCKET, socket.SO_RCVBUF, cls.SO_RCVBUF)
    setopt(socket.SOL_SOCKET, socket.SO_SNDBUF, cls.SO_SNDBUF)

    # Remaining options are Linux-specific; guard with hasattr.
    if hasattr(socket, "TCP_QUICKACK"):
        setopt(socket.SOL_TCP, socket.TCP_QUICKACK, cls.TCP_QUICKACK)

    if hasattr(socket, "TCP_USER_TIMEOUT"):
        setopt(
            socket.SOL_TCP, socket.TCP_USER_TIMEOUT, cls.TCP_USER_TIMEOUT
        )

aiperf.clients.model_endpoint_info

Model endpoint information.

This module contains the pydantic models that encapsulate the information needed to send requests to an inference server, primarily around the model, endpoint, and additional request payload information.

EndpointInfo

Bases: AIPerfBaseModel

Information about an endpoint.

Source code in aiperf/clients/model_endpoint_info.py
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
class EndpointInfo(AIPerfBaseModel):
    """Information about an endpoint.

    Encapsulates everything needed to address a single inference endpoint:
    the payload type, base URL, streaming mode, auth, SSL options, timeout,
    and any extra inputs merged into every request.
    """

    type: EndpointType = Field(
        default=EndpointType.OPENAI_CHAT_COMPLETIONS,
        description="The type of request payload to use for the endpoint.",
    )
    base_url: str | None = Field(
        default=None,
        description="URL of the endpoint.",
    )
    custom_endpoint: str | None = Field(
        default=None,
        description="Custom endpoint to use for the models.",
    )
    url_params: dict[str, Any] | None = Field(
        default=None, description="Custom URL parameters to use for the endpoint."
    )
    streaming: bool = Field(
        default=False,
        description="Whether the endpoint supports streaming.",
    )
    headers: dict[str, str] | None = Field(
        default=None,
        description="Custom URL headers to use for the endpoint.",
    )
    api_key: str | None = Field(
        default=None,
        description="API key to use for the endpoint.",
    )
    ssl_options: dict[str, Any] | None = Field(
        default=None,
        description="SSL options to use for the endpoint.",
    )
    timeout: float = Field(
        default=EndpointDefaults.TIMEOUT,
        description="The timeout in seconds for each request to the endpoint.",
    )
    extra: dict[str, Any] | None = Field(
        default=None,
        description="Additional inputs to include with every request. "
        "You can repeat this flag for multiple inputs. Inputs should be in an 'input_name:value' format. "
        "Alternatively, a string representing a json formatted dict can be provided.",
    )

    @classmethod
    def from_user_config(cls, user_config: UserConfig) -> "EndpointInfo":
        """Create an EndpointInfo from a UserConfig."""
        # Map the flat user configuration onto this model; fields not present
        # in UserConfig (url_params, ssl_options) keep their declared defaults.
        return cls(
            type=EndpointType(user_config.endpoint.type),
            custom_endpoint=user_config.endpoint.custom_endpoint,
            streaming=user_config.endpoint.streaming,
            base_url=user_config.endpoint.url,
            headers=user_config.input.headers,
            extra=user_config.input.extra,
            timeout=user_config.endpoint.timeout_seconds,
            api_key=user_config.endpoint.api_key,
        )

from_user_config(user_config) classmethod

Create an EndpointInfo from a UserConfig.

Source code in aiperf/clients/model_endpoint_info.py
108
109
110
111
112
113
114
115
116
117
118
119
120
@classmethod
def from_user_config(cls, user_config: UserConfig) -> "EndpointInfo":
    """Create an EndpointInfo from a UserConfig."""
    # Map the flat user configuration onto this model; fields not present
    # in UserConfig keep their declared defaults.
    return cls(
        type=EndpointType(user_config.endpoint.type),
        custom_endpoint=user_config.endpoint.custom_endpoint,
        streaming=user_config.endpoint.streaming,
        base_url=user_config.endpoint.url,
        headers=user_config.input.headers,
        extra=user_config.input.extra,
        timeout=user_config.endpoint.timeout_seconds,
        api_key=user_config.endpoint.api_key,
    )

ModelEndpointInfo

Bases: AIPerfBaseModel

Information about a model endpoint.

Source code in aiperf/clients/model_endpoint_info.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
class ModelEndpointInfo(AIPerfBaseModel):
    """Information about a model endpoint.

    Pairs the list of models with the endpoint configuration used to reach
    them, and derives the full request URL from the two.
    """

    models: ModelListInfo = Field(
        ...,
        description="The models to use for the endpoint.",
    )
    endpoint: EndpointInfo = Field(
        ...,
        description="The endpoint to use for the models.",
    )

    @classmethod
    def from_user_config(cls, user_config: UserConfig) -> "ModelEndpointInfo":
        """Create a ModelEndpointInfo from a UserConfig."""
        return cls(
            models=ModelListInfo.from_user_config(user_config),
            endpoint=EndpointInfo.from_user_config(user_config),
        )

    @property
    def url(self) -> str:
        """Get the full URL for the endpoint."""
        # Strip a trailing slash from the base so the join below never
        # produces a double slash; empty string when no base_url is set.
        url = self.endpoint.base_url.rstrip("/") if self.endpoint.base_url else ""
        # A user-supplied custom endpoint takes precedence over the
        # endpoint type's default path.
        if self.endpoint.custom_endpoint:
            url += "/" + self.endpoint.custom_endpoint.lstrip("/")
        elif path := self.endpoint.type.endpoint_path:
            url += "/" + path.lstrip("/")
        return url

    @property
    def primary_model(self) -> ModelInfo:
        """Get the primary model."""
        # The first entry in the model list is treated as primary.
        return self.models.models[0]

    @property
    def primary_model_name(self) -> str:
        """Get the primary model name."""
        return self.primary_model.name

primary_model property

Get the primary model.

primary_model_name property

Get the primary model name.

url property

Get the full URL for the endpoint.

from_user_config(user_config) classmethod

Create a ModelEndpointInfo from a UserConfig.

Source code in aiperf/clients/model_endpoint_info.py
135
136
137
138
139
140
141
@classmethod
def from_user_config(cls, user_config: UserConfig) -> "ModelEndpointInfo":
    """Create a ModelEndpointInfo from a UserConfig."""
    # Delegate to the two sub-models' own from_user_config constructors.
    return cls(
        models=ModelListInfo.from_user_config(user_config),
        endpoint=EndpointInfo.from_user_config(user_config),
    )

ModelInfo

Bases: AIPerfBaseModel

Information about a model.

Source code in aiperf/clients/model_endpoint_info.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class ModelInfo(AIPerfBaseModel):
    """Information about a model.

    Identifies a single model by name, with an optional version and the
    modality used to select the request payload type.
    """

    name: str = Field(
        ...,
        min_length=1,
        description="The name of the model. This is used to identify the model.",
    )
    version: str | None = Field(
        default=None,
        description="The version of the model.",
    )
    modality: Modality = Field(
        default=Modality.TEXT,
        description="The modality of the model. This is used to determine the type of request payload "
        "to use for the endpoint. If CUSTOM, the model is not supported.",
    )

ModelListInfo

Bases: AIPerfBaseModel

Information about a list of models.

Source code in aiperf/clients/model_endpoint_info.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class ModelListInfo(AIPerfBaseModel):
    """Information about a list of models.

    Holds one or more models together with the strategy used to pick which
    model serves a given request.
    """

    models: list[ModelInfo] = Field(
        ...,
        min_length=1,
        description="The models to use for the endpoint.",
    )
    model_selection_strategy: ModelSelectionStrategy = Field(
        ...,
        description="The strategy to use for selecting the model to use for the endpoint.",
    )

    @classmethod
    def from_user_config(cls, user_config: UserConfig) -> "ModelListInfo":
        """Create a ModelListInfo from a UserConfig."""
        # Each configured model name becomes a ModelInfo with default
        # version/modality.
        return cls(
            models=[
                ModelInfo(name=model) for model in user_config.endpoint.model_names
            ],
            model_selection_strategy=user_config.endpoint.model_selection_strategy,
        )

from_user_config(user_config) classmethod

Create a ModelListInfo from a UserConfig.

Source code in aiperf/clients/model_endpoint_info.py
52
53
54
55
56
57
58
59
60
@classmethod
def from_user_config(cls, user_config: UserConfig) -> "ModelListInfo":
    """Create a ModelListInfo from a UserConfig."""
    # Each configured model name becomes a ModelInfo with default
    # version/modality.
    return cls(
        models=[
            ModelInfo(name=model) for model in user_config.endpoint.model_names
        ],
        model_selection_strategy=user_config.endpoint.model_selection_strategy,
    )

aiperf.clients.openai.openai_aiohttp

OpenAIClientAioHttp

Bases: AioHttpClientMixin, AIPerfLoggerMixin, ABC

Inference client for OpenAI based requests using aiohttp.

Source code in aiperf/clients/openai/openai_aiohttp.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@InferenceClientFactory.register_all(
    EndpointType.OPENAI_CHAT_COMPLETIONS,
    EndpointType.OPENAI_COMPLETIONS,
    EndpointType.OPENAI_EMBEDDINGS,
    EndpointType.OPENAI_RESPONSES,
)
class OpenAIClientAioHttp(AioHttpClientMixin, AIPerfLoggerMixin, ABC):
    """Inference client for OpenAI based requests using aiohttp."""

    def __init__(self, model_endpoint: ModelEndpointInfo, **kwargs) -> None:
        super().__init__(model_endpoint, **kwargs)
        self.model_endpoint = model_endpoint

    def get_headers(self, model_endpoint: ModelEndpointInfo) -> dict[str, str]:
        """Get the headers for the given endpoint."""
        endpoint = model_endpoint.endpoint

        headers = {
            "User-Agent": "aiperf/1.0",
            "Content-Type": "application/json",
            # Streaming responses arrive as server-sent events.
            "Accept": "text/event-stream"
            if endpoint.streaming
            else "application/json",
        }

        if endpoint.api_key:
            headers["Authorization"] = f"Bearer {endpoint.api_key}"

        # User-supplied headers take precedence over the defaults above.
        if endpoint.headers:
            headers.update(endpoint.headers)

        return headers

    def get_url(self, model_endpoint: ModelEndpointInfo) -> str:
        """Get the URL for the given endpoint."""
        # Default to plain HTTP when the caller supplied a bare host[:port].
        url = model_endpoint.url
        return url if url.startswith("http") else f"http://{url}"

    async def send_request(
        self,
        model_endpoint: ModelEndpointInfo,
        payload: dict[str, Any],
    ) -> RequestRecord:
        """Send OpenAI request using aiohttp."""

        # capture start time before request is sent in the case of an error
        started_ns = time.perf_counter_ns()
        try:
            self.debug(
                lambda: f"Sending OpenAI request to {model_endpoint.url}, payload: {payload}"
            )
            record = await self.post_request(
                self.get_url(model_endpoint),
                json.dumps(payload),
                self.get_headers(model_endpoint),
            )
            record.request = payload
        except Exception as e:
            # Fold the failure into a record so the caller always gets one.
            record = RequestRecord(
                request=payload,
                start_perf_ns=started_ns,
                end_perf_ns=time.perf_counter_ns(),
                error=ErrorDetails(type=e.__class__.__name__, message=str(e)),
            )
            self.exception(f"Error in OpenAI request: {e.__class__.__name__} {str(e)}")
        return record

get_headers(model_endpoint)

Get the headers for the given endpoint.

Source code in aiperf/clients/openai/openai_aiohttp.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def get_headers(self, model_endpoint: ModelEndpointInfo) -> dict[str, str]:
    """Get the headers for the given endpoint."""

    # Streaming responses arrive as server-sent events.
    accept = (
        "text/event-stream"
        if model_endpoint.endpoint.streaming
        else "application/json"
    )

    headers = {
        "User-Agent": "aiperf/1.0",
        "Content-Type": "application/json",
        "Accept": accept,
    }
    if model_endpoint.endpoint.api_key:
        headers["Authorization"] = f"Bearer {model_endpoint.endpoint.api_key}"
    # User-supplied headers take precedence over the defaults above.
    if model_endpoint.endpoint.headers:
        headers.update(model_endpoint.endpoint.headers)
    return headers

get_url(model_endpoint)

Get the URL for the given endpoint.

Source code in aiperf/clients/openai/openai_aiohttp.py
50
51
52
53
54
55
def get_url(self, model_endpoint: ModelEndpointInfo) -> str:
    """Get the URL for the given endpoint."""
    # Default to plain HTTP when the caller supplied a bare host[:port].
    url = model_endpoint.url
    return url if url.startswith("http") else f"http://{url}"

send_request(model_endpoint, payload) async

Send OpenAI request using aiohttp.

Source code in aiperf/clients/openai/openai_aiohttp.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
async def send_request(
    self,
    model_endpoint: ModelEndpointInfo,
    payload: dict[str, Any],
) -> RequestRecord:
    """Send OpenAI request using aiohttp."""

    # capture start time before request is sent in the case of an error
    start_perf_ns = time.perf_counter_ns()
    try:
        self.debug(
            lambda: f"Sending OpenAI request to {model_endpoint.url}, payload: {payload}"
        )

        record = await self.post_request(
            self.get_url(model_endpoint),
            json.dumps(payload),
            self.get_headers(model_endpoint),
        )
        # Attach the request payload so downstream consumers can inspect it.
        record.request = payload

    except Exception as e:
        # Fold the failure into a record so the caller always gets one back.
        record = RequestRecord(
            request=payload,
            start_perf_ns=start_perf_ns,
            end_perf_ns=time.perf_counter_ns(),
            error=ErrorDetails(type=e.__class__.__name__, message=str(e)),
        )
        self.exception(f"Error in OpenAI request: {e.__class__.__name__} {str(e)}")

    return record

aiperf.clients.openai.openai_chat

OpenAIChatCompletionRequestConverter

Bases: AIPerfLoggerMixin

Request converter for OpenAI chat completion requests.

Source code in aiperf/clients/openai/openai_chat.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
@RequestConverterFactory.register(EndpointType.OPENAI_CHAT_COMPLETIONS)
class OpenAIChatCompletionRequestConverter(AIPerfLoggerMixin):
    """Request converter for OpenAI chat completion requests."""

    async def format_payload(
        self,
        model_endpoint: ModelEndpointInfo,
        turn: Turn,
    ) -> dict[str, Any]:
        """Format payload for a chat completion request.

        Args:
            model_endpoint: The model endpoint the request targets.
            turn: The conversation turn to convert into a payload.

        Returns:
            A dict ready to be serialized as the JSON request body.
        """

        messages = self._create_messages(turn)

        payload = {
            "messages": messages,
            # Fall back to the endpoint's primary model when the turn does
            # not specify one.
            "model": turn.model or model_endpoint.primary_model_name,
            "stream": model_endpoint.endpoint.streaming,
        }

        # Extra user-provided inputs may override the keys above.
        if model_endpoint.endpoint.extra:
            payload.update(model_endpoint.endpoint.extra)

        self.debug(lambda: f"Formatted payload: {payload}")
        return payload

    def _create_messages(self, turn: Turn) -> list[dict[str, Any]]:
        """Build the OpenAI ``messages`` list from the turn's contents.

        Empty content entries are skipped. Audio contents must be in the
        'format,b64_audio' form or a ValueError is raised.
        """
        message = {
            "role": turn.role or DEFAULT_ROLE,
            "content": [],
        }
        for text in turn.texts:
            for content in text.contents:
                if not content:
                    continue
                message["content"].append({"type": "text", "text": content})

        for image in turn.images:
            for content in image.contents:
                if not content:
                    continue
                message["content"].append(
                    {"type": "image_url", "image_url": {"url": content}}
                )

        for audio in turn.audios:
            for content in audio.contents:
                if not content:
                    continue
                if "," not in content:
                    raise ValueError(
                        "Audio content must be in the format 'format,b64_audio'."
                    )
                # `audio_format` (not `format`) to avoid shadowing the builtin.
                audio_format, b64_audio = content.split(",", 1)
                message["content"].append(
                    {
                        "type": "input_audio",
                        "input_audio": {
                            "data": b64_audio,
                            "format": audio_format,
                        },
                    }
                )

        # Hotfix for Dynamo API which does not yet support a list of messages
        if (
            len(message["content"]) == 1
            and "text" in message["content"][0]
            and len(turn.texts) == 1
        ):
            messages = [
                {
                    "role": message["role"],
                    # NOTE(review): `name` may be None here — confirm the API
                    # accepts a null name or filter it out upstream.
                    "name": turn.texts[0].name,
                    "content": message["content"][0].get("text"),
                }
            ]
        else:
            messages = [message]
        return messages

format_payload(model_endpoint, turn) async

Format payload for a chat completion request.

Source code in aiperf/clients/openai/openai_chat.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
async def format_payload(
    self,
    model_endpoint: ModelEndpointInfo,
    turn: Turn,
) -> dict[str, Any]:
    """Format payload for a chat completion request."""

    messages = self._create_messages(turn)

    payload = {
        "messages": messages,
        # Fall back to the endpoint's primary model when the turn does not
        # specify one.
        "model": turn.model or model_endpoint.primary_model_name,
        "stream": model_endpoint.endpoint.streaming,
    }

    # Extra user-provided inputs may override the keys above.
    if model_endpoint.endpoint.extra:
        payload.update(model_endpoint.endpoint.extra)

    self.debug(lambda: f"Formatted payload: {payload}")
    return payload

aiperf.clients.openai.openai_completions

OpenAICompletionRequestConverter

Bases: AIPerfLoggerMixin

Request converter for OpenAI completion requests.

Source code in aiperf/clients/openai/openai_completions.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
@RequestConverterFactory.register(EndpointType.OPENAI_COMPLETIONS)
class OpenAICompletionRequestConverter(AIPerfLoggerMixin):
    """Request converter for OpenAI completion requests."""

    async def format_payload(
        self,
        model_endpoint: ModelEndpointInfo,
        turn: Turn,
    ) -> dict[str, Any]:
        """Format payload for a completion request."""

        # Flatten every non-empty text content into the prompt list.
        prompt_list: list = []
        for text in turn.texts:
            prompt_list.extend(content for content in text.contents if content)

        payload = {
            "prompt": prompt_list,
            "model": turn.model or model_endpoint.primary_model_name,
            "stream": model_endpoint.endpoint.streaming,
        }

        # Extra user-provided inputs may override the keys above.
        payload.update(model_endpoint.endpoint.extra or {})

        self.debug(lambda: f"Formatted payload: {payload}")
        return payload

format_payload(model_endpoint, turn) async

Format payload for a completion request.

Source code in aiperf/clients/openai/openai_completions.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
async def format_payload(
    self,
    model_endpoint: ModelEndpointInfo,
    turn: Turn,
) -> dict[str, Any]:
    """Format payload for a completion request."""

    # Flatten every non-empty text content into the prompt list.
    prompts = [
        content for text in turn.texts for content in text.contents if content
    ]

    extra = model_endpoint.endpoint.extra or {}

    payload = {
        "prompt": prompts,
        "model": turn.model or model_endpoint.primary_model_name,
        "stream": model_endpoint.endpoint.streaming,
    }

    # Extra user-provided inputs may override the keys above.
    if extra:
        payload.update(extra)

    self.debug(lambda: f"Formatted payload: {payload}")
    return payload

aiperf.clients.openai.openai_embeddings

OpenAIEmbeddingsRequestConverter

Bases: AIPerfLoggerMixin

Request converter for OpenAI embeddings requests.

Source code in aiperf/clients/openai/openai_embeddings.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
@RequestConverterFactory.register(EndpointType.OPENAI_EMBEDDINGS)
class OpenAIEmbeddingsRequestConverter(AIPerfLoggerMixin):
    """Request converter for OpenAI embeddings requests."""

    async def format_payload(
        self,
        model_endpoint: ModelEndpointInfo,
        turn: Turn,
    ) -> dict[str, Any]:
        """Format payload for an embeddings request."""

        # Flatten every non-empty text content into the embeddings input.
        input_texts: list = []
        for text in turn.texts:
            input_texts.extend(content for content in text.contents if content)

        payload = {
            "model": turn.model or model_endpoint.primary_model_name,
            "input": input_texts,
        }

        # Extra user-provided inputs may override the keys above.
        payload.update(model_endpoint.endpoint.extra or {})

        self.debug(lambda: f"Formatted payload: {payload}")
        return payload

format_payload(model_endpoint, turn) async

Format payload for an embeddings request.

Source code in aiperf/clients/openai/openai_embeddings.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
async def format_payload(
    self,
    model_endpoint: ModelEndpointInfo,
    turn: Turn,
) -> dict[str, Any]:
    """Format payload for an embeddings request."""

    # Flatten every non-empty text content into the embeddings input.
    prompts = [
        content for text in turn.texts for content in text.contents if content
    ]

    extra = model_endpoint.endpoint.extra or {}

    payload = {
        "model": turn.model or model_endpoint.primary_model_name,
        "input": prompts,
    }

    # Extra user-provided inputs may override the keys above.
    if extra:
        payload.update(extra)

    self.debug(lambda: f"Formatted payload: {payload}")
    return payload

aiperf.clients.openai.openai_responses

OpenAIResponsesRequestConverter

Bases: AIPerfLoggerMixin

Request converter for OpenAI Responses requests.

Source code in aiperf/clients/openai/openai_responses.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
@RequestConverterFactory.register(EndpointType.OPENAI_RESPONSES)
class OpenAIResponsesRequestConverter(AIPerfLoggerMixin):
    """Request converter for OpenAI Responses requests."""

    async def format_payload(
        self,
        model_endpoint: ModelEndpointInfo,
        turn: Turn,
    ) -> dict[str, Any]:
        """Format payload for a responses request.

        Args:
            model_endpoint: The model endpoint the request targets.
            turn: The conversation turn to convert into a payload.

        Returns:
            A dict ready to be serialized as the JSON request body.
        """

        # TODO: Add support for image and audio inputs.
        prompts = [
            content for text in turn.texts for content in text.contents if content
        ]

        # Copy the extra dict: the pop() below would otherwise mutate the
        # shared endpoint config and drop max_output_tokens for every
        # subsequent request.
        extra = dict(model_endpoint.endpoint.extra or {})

        payload = {
            "input": prompts,
            "model": model_endpoint.primary_model_name,
            # TODO: How do we handle max_output_tokens? Should be provided by OSL logic
            "max_output_tokens": extra.pop("max_output_tokens", None),
            "stream": model_endpoint.endpoint.streaming,
        }

        # Remaining extra user-provided inputs may override the keys above.
        if extra:
            payload.update(extra)

        self.debug(lambda: f"Formatted payload: {payload}")
        return payload

format_payload(model_endpoint, turn) async

Format payload for a responses request.

Source code in aiperf/clients/openai/openai_responses.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
async def format_payload(
    self,
    model_endpoint: ModelEndpointInfo,
    turn: Turn,
) -> dict[str, Any]:
    """Format payload for a responses request."""

    # TODO: Add support for image and audio inputs.
    prompts = [
        content for text in turn.texts for content in text.contents if content
    ]

    # NOTE(review): `extra` aliases the shared endpoint config dict, so the
    # pop() below mutates it and drops max_output_tokens for later requests
    # — confirm and copy the dict if that is unintended.
    extra = model_endpoint.endpoint.extra or {}

    payload = {
        "input": prompts,
        "model": model_endpoint.primary_model_name,
        # TODO: How do we handle max_output_tokens? Should be provided by OSL logic
        "max_output_tokens": extra.pop("max_output_tokens", None),
        "stream": model_endpoint.endpoint.streaming,
    }

    # Remaining extra user-provided inputs may override the keys above.
    if extra:
        payload.update(extra)

    self.debug(lambda: f"Formatted payload: {payload}")
    return payload

aiperf.common.aiperf_logger

AIPerfLogger

Logger for AIPerf messages with lazy evaluation support for f-strings.

This logger supports lazy evaluation of f-strings through lambdas to avoid expensive string formatting operations when the log level is not enabled.

It also extends the standard logging module with additional log levels
  • TRACE (TRACE < DEBUG)
  • NOTICE (INFO < NOTICE < WARNING)
  • SUCCESS (WARNING < SUCCESS < ERROR)
Usage

logger = AIPerfLogger("my_logger") logger.debug(lambda: f"Processing {item} with {count} items") logger.info("Simple string message") logger.notice("Notice message") logger.success("Benchmark completed successfully")

Need to pass local variables to the lambda to avoid them going out of scope

logger.debug(lambda i=i: f"Binding loop variable: {i}") logger.exception(f"Direct f-string usage: {e}")

Source code in aiperf/common/aiperf_logger.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
class AIPerfLogger:
    """Logger for AIPerf messages with lazy evaluation support for f-strings.

    This logger supports lazy evaluation of f-strings through lambdas to avoid
    expensive string formatting operations when the log level is not enabled.

    It also extends the standard logging module with additional log levels:
        - TRACE    (TRACE < DEBUG)
        - NOTICE   (INFO < NOTICE < WARNING)
        - SUCCESS  (WARNING < SUCCESS < ERROR)

    Usage:
        logger = AIPerfLogger("my_logger")
        logger.debug(lambda: f"Processing {item} with {count} items")
        logger.info("Simple string message")
        logger.notice("Notice message")
        logger.success("Benchmark completed successfully")
        # Need to pass local variables to the lambda to avoid them going out of scope
        logger.debug(lambda i=i: f"Binding loop variable: {i}")
        logger.exception(f"Direct f-string usage: {e}")
    """

    def __init__(self, logger_name: str):
        self.logger_name = logger_name
        self._logger = logging.getLogger(logger_name)

        # Cache the internal logging module's _log method
        self._internal_log = self._logger._log

        # Forward the internal findCaller method to our custom method
        self._logger.findCaller = self.find_caller

        # Python style method names
        self.is_enabled_for = self._logger.isEnabledFor
        self.set_level = self._logger.setLevel
        self.get_effective_level = self._logger.getEffectiveLevel

        # Legacy logging method compatibility / passthrough
        self.isEnabledFor = self._logger.isEnabledFor
        self.setLevel = self._logger.setLevel
        self.getEffectiveLevel = self._logger.getEffectiveLevel
        self.handlers = self._logger.handlers
        self.addHandler = self._logger.addHandler
        self.removeHandler = self._logger.removeHandler
        self.hasHandlers = self._logger.hasHandlers
        self.root = self._logger.root

    @property
    def is_debug_enabled(self) -> bool:
        return self.is_enabled_for(_DEBUG)

    @property
    def is_trace_enabled(self) -> bool:
        return self.is_enabled_for(_TRACE)

    def _log(self, level: int, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Internal log method that handles lazy evaluation of f-strings."""
        if callable(msg):
            # NOTE: Internal python logging _log method requires a tuple for the args, even if there are no args
            self._internal_log(level, msg(*args), (), **kwargs)
        else:
            self._internal_log(level, msg, args, **kwargs)

    @classmethod
    def is_valid_level(cls, level: int | str) -> bool:
        """Check if the given level is a valid level."""
        if isinstance(level, str):
            return level in [
                "TRACE",
                "DEBUG",
                "INFO",
                "NOTICE",
                "WARNING",
                "SUCCESS",
                "ERROR",
                "CRITICAL",
            ]
        else:
            return level in [
                _TRACE,
                _DEBUG,
                _INFO,
                _NOTICE,
                _WARNING,
                _SUCCESS,
                _ERROR,
                _CRITICAL,
            ]

    @classmethod
    def get_level_number(cls, level: int | str) -> int:
        """Get the numeric level for the given level."""
        if isinstance(level, str):
            return getattr(cls, level.upper())
        else:
            return level

    def find_caller(
        self, stack_info=False, stacklevel=1
    ) -> tuple[str, int, str, str | None]:
        """
        NOTE: This is a modified version of the findCaller method in the logging module,
        in order to allow us to add custom ignored files.

        Find the stack frame of the caller so that we can note the source
        file name, line number and function name.
        """
        f = currentframe()
        # On some versions of IronPython, currentframe() returns None if
        # IronPython isn't run with -X:Frames.
        if f is not None:
            f = f.f_back
        orig_f = f
        while f and stacklevel > 1:
            f = f.f_back
            stacklevel -= 1
        if not f:
            f = orig_f
        rv = "(unknown file)", 0, "(unknown function)", None
        while f and hasattr(f, "f_code"):
            co = f.f_code
            filename = os.path.normcase(co.co_filename)
            # NOTE: The if-statement below was modified to use our own list of ignored files (_ignored_files).
            # This is required to avoid it appearing as all logs are coming from this file.
            if filename in _ignored_files:
                f = f.f_back
                continue
            sinfo = None
            if stack_info:
                sio = io.StringIO()
                sio.write("Stack (most recent call last):\n")
                traceback.print_stack(f, file=sio)
                sinfo = sio.getvalue()
                if sinfo[-1] == "\n":
                    sinfo = sinfo[:-1]
                sio.close()
            rv = (co.co_filename, f.f_lineno, co.co_name, sinfo)
            break
        return rv

    def log(self, level: int, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(level):
            # Unpack args like the level-specific methods do; passing `args`
            # as a single positional would double-wrap the tuple and break
            # both lambda invocation and %-style formatting.
            self._log(level, msg, *args, **kwargs)

    def trace(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a trace message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_TRACE):
            self._log(_TRACE, msg, *args, **kwargs)

    def debug(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a debug message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_DEBUG):
            self._log(_DEBUG, msg, *args, **kwargs)

    def info(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log an info message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_INFO):
            self._log(_INFO, msg, *args, **kwargs)

    def notice(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a notice message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_NOTICE):
            self._log(_NOTICE, msg, *args, **kwargs)

    def warning(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a warning message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_WARNING):
            self._log(_WARNING, msg, *args, **kwargs)

    def success(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a success message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_SUCCESS):
            self._log(_SUCCESS, msg, *args, **kwargs)

    def error(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log an error message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_ERROR):
            self._log(_ERROR, msg, *args, **kwargs)

    def exception(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log an exception message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_ERROR):
            self._log(_ERROR, msg, *args, exc_info=True, **kwargs)

    def critical(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a critical message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_CRITICAL):
            self._log(_CRITICAL, msg, *args, **kwargs)

    def trace_or_debug(
        self,
        trace_msg: str | Callable[..., str],
        debug_msg: str | Callable[..., str],
    ) -> None:
        """Log different messages depending on the level of the logger.

        This method is used to log a message at the trace level if the trace level is enabled,
        otherwise it will log a debug message. It enables us to use a single method to log
        different messages depending on the level of the logger. Use this method to provide
        full dumps of data when the logger is in trace mode, and a more concise message when
        the logger is in debug mode.

        Example:
        ```python
        self.trace_or_debug(
            lambda: f"Received request: {request}",
            lambda: f"Received request id: {request.id}",
        )
        ```
        """
        if self.is_enabled_for(_TRACE):
            self._log(_TRACE, trace_msg)
        elif self.is_enabled_for(_DEBUG):
            self._log(_DEBUG, debug_msg)

critical(msg, *args, **kwargs)

Log a critical message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
210
211
212
213
def critical(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a CRITICAL-level record; ``msg`` may be a zero-arg callable evaluated lazily."""
    if not self.is_enabled_for(_CRITICAL):
        return
    self._log(_CRITICAL, msg, *args, **kwargs)

debug(msg, *args, **kwargs)

Log a debug message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
175
176
177
178
def debug(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a DEBUG-level record; ``msg`` may be a zero-arg callable evaluated lazily."""
    if not self.is_enabled_for(_DEBUG):
        return
    self._log(_DEBUG, msg, *args, **kwargs)

error(msg, *args, **kwargs)

Log an error message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
200
201
202
203
def error(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit an ERROR-level record; ``msg`` may be a zero-arg callable evaluated lazily."""
    if not self.is_enabled_for(_ERROR):
        return
    self._log(_ERROR, msg, *args, **kwargs)

exception(msg, *args, **kwargs)

Log an exception message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
205
206
207
208
def exception(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit an ERROR-level record with the active exception info attached."""
    if not self.is_enabled_for(_ERROR):
        return
    # exc_info=True pulls the in-flight exception's traceback into the record.
    self._log(_ERROR, msg, *args, exc_info=True, **kwargs)

find_caller(stack_info=False, stacklevel=1)

NOTE: This is a modified version of the findCaller method in the logging module, in order to allow us to add custom ignored files.

Find the stack frame of the caller so that we can note the source file name, line number and function name.

Source code in aiperf/common/aiperf_logger.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def find_caller(
    self, stack_info: bool = False, stacklevel: int = 1
) -> tuple[str, int, str, str | None]:
    """
    NOTE: This is a modified version of the findCaller method in the logging module,
    in order to allow us to add custom ignored files.

    Find the stack frame of the caller so that we can note the source
    file name, line number and function name.

    Returns:
        A ``(filename, lineno, funcname, stack_info_text)`` tuple; the last
        element is None unless ``stack_info`` is True.
    """
    f = currentframe()
    # On some versions of IronPython, currentframe() returns None if
    # IronPython isn't run with -X:Frames.
    if f is not None:
        f = f.f_back
    # Remember where we started so we can recover if stacklevel walks off the stack.
    orig_f = f
    # Skip `stacklevel - 1` additional frames, mirroring stdlib logging semantics.
    while f and stacklevel > 1:
        f = f.f_back
        stacklevel -= 1
    if not f:
        f = orig_f
    # Fallback result when no suitable frame is found.
    rv = "(unknown file)", 0, "(unknown function)", None
    while f and hasattr(f, "f_code"):
        co = f.f_code
        filename = os.path.normcase(co.co_filename)
        # NOTE: The if-statement below was modified to use our own list of ignored files (_ignored_files).
        # This is required to avoid it appearing as all logs are coming from this file.
        if filename in _ignored_files:
            f = f.f_back
            continue
        sinfo = None
        if stack_info:
            # Render the traceback into a string, stripping the trailing newline.
            sio = io.StringIO()
            sio.write("Stack (most recent call last):\n")
            traceback.print_stack(f, file=sio)
            sinfo = sio.getvalue()
            if sinfo[-1] == "\n":
                sinfo = sinfo[:-1]
            sio.close()
        rv = (co.co_filename, f.f_lineno, co.co_name, sinfo)
        break
    return rv

get_level_number(level) classmethod

Get the numeric level for the given level.

Source code in aiperf/common/aiperf_logger.py
114
115
116
117
118
119
120
@classmethod
def get_level_number(cls, level: int | str) -> int:
    """Get the numeric level for the given level."""
    if isinstance(level, str):
        return getattr(cls, level.upper())
    else:
        return level

info(msg, *args, **kwargs)

Log an info message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
180
181
182
183
def info(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit an INFO-level record; ``msg`` may be a zero-arg callable evaluated lazily."""
    if not self.is_enabled_for(_INFO):
        return
    self._log(_INFO, msg, *args, **kwargs)

is_valid_level(level) classmethod

Check if the given level is a valid level.

Source code in aiperf/common/aiperf_logger.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
@classmethod
def is_valid_level(cls, level: int | str) -> bool:
    """Check if the given level is a valid level."""
    if isinstance(level, str):
        return level in [
            "TRACE",
            "DEBUG",
            "INFO",
            "NOTICE",
            "WARNING",
            "SUCCESS",
            "ERROR",
            "CRITICAL",
        ]
    else:
        return level in [
            _TRACE,
            _DEBUG,
            _INFO,
            _NOTICE,
            _WARNING,
            _SUCCESS,
            _ERROR,
            _CRITICAL,
        ]

log(level, msg, *args, **kwargs)

Log a message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
165
166
167
168
def log(self, level: int, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a message with support for lazy evaluation using lambdas."""
    if self.is_enabled_for(level):
        self._log(level, msg, args, **kwargs)

notice(msg, *args, **kwargs)

Log a notice message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
185
186
187
188
def notice(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a NOTICE-level record; ``msg`` may be a zero-arg callable evaluated lazily."""
    if not self.is_enabled_for(_NOTICE):
        return
    self._log(_NOTICE, msg, *args, **kwargs)

success(msg, *args, **kwargs)

Log a success message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
195
196
197
198
def success(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a SUCCESS-level record; ``msg`` may be a zero-arg callable evaluated lazily."""
    if not self.is_enabled_for(_SUCCESS):
        return
    self._log(_SUCCESS, msg, *args, **kwargs)

trace(msg, *args, **kwargs)

Log a trace message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
170
171
172
173
def trace(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a TRACE-level record; ``msg`` may be a zero-arg callable evaluated lazily."""
    if not self.is_enabled_for(_TRACE):
        return
    self._log(_TRACE, msg, *args, **kwargs)

trace_or_debug(trace_msg, debug_msg)

Log different messages depending on the level of the logger.

This method is used to log a message at the trace level if the trace level is enabled, otherwise it will log a debug message. It enables us to use a single method to log different messages depending on the level of the logger. Use this method to provide full dumps of data when the logger is in trace mode, and a more concise message when the logger is in debug mode.

Example:

self.trace_or_debug(
    lambda: f"Received request: {request}",
    lambda: f"Received request id: {request.id}",
)
Source code in aiperf/common/aiperf_logger.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
def trace_or_debug(
    self,
    trace_msg: str | Callable[..., str],
    debug_msg: str | Callable[..., str],
) -> None:
    """Emit one of two messages, chosen by the logger's effective level.

    When TRACE is enabled, ``trace_msg`` is emitted at TRACE; otherwise, when
    DEBUG is enabled, ``debug_msg`` is emitted at DEBUG; otherwise nothing is
    logged. Useful for pairing a full data dump (trace) with a concise
    summary (debug) from a single call site.

    Example:
    ```python
    self.trace_or_debug(
        lambda: f"Received request: {request}",
        lambda: f"Received request id: {request.id}",
    )
    ```
    """
    if self.is_enabled_for(_TRACE):
        self._log(_TRACE, trace_msg)
        return
    if self.is_enabled_for(_DEBUG):
        self._log(_DEBUG, debug_msg)

warning(msg, *args, **kwargs)

Log a warning message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
190
191
192
193
def warning(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a WARNING-level record; ``msg`` may be a zero-arg callable evaluated lazily."""
    if not self.is_enabled_for(_WARNING):
        return
    self._log(_WARNING, msg, *args, **kwargs)

aiperf.common.base_component_service

BaseComponentService

Bases: BaseService

Base class for all Component services.

This class provides a common interface for all Component services in the AIPerf framework such as the Timing Manager, Dataset Manager, etc.

It extends the BaseService by adding heartbeat and registration functionality, as well as publishing the current state of the service to the system controller.

Source code in aiperf/common/base_component_service.py
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
@implements_protocol(ServiceProtocol)
class BaseComponentService(BaseService):
    """Base class for all Component services.

    This class provides a common interface for all Component services in the AIPerf
    framework such as the Timing Manager, Dataset Manager, etc.

    It extends the BaseService by adding heartbeat and registration functionality, as well as
    publishing the current state of the service to the system controller.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str | None = None,
        **kwargs,
    ) -> None:
        # No component-specific state; simply forwards configuration to BaseService.
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
            **kwargs,
        )

    @background_task(
        # Interval is resolved per instance so each service honors its own config.
        interval=lambda self: self.service_config.heartbeat_interval_seconds,
        immediate=False,
    )
    async def _heartbeat_task(self) -> None:
        """Send a heartbeat notification to the system controller."""
        await self.publish(
            HeartbeatMessage(
                service_id=self.service_id,
                service_type=self.service_type,
                state=self.state,
            )
        )

    @on_start
    async def _register_service_on_start(self) -> None:
        """Register the service with the system controller on startup.

        Retries up to DEFAULT_MAX_REGISTRATION_ATTEMPTS times, reusing the same
        command id so duplicate registrations can be deduplicated controller-side.
        """
        self.debug(
            lambda: f"Attempting to register service {self} ({self.service_id}) with system controller"
        )
        result = None
        command_message = RegisterServiceCommand(
            command_id=str(uuid.uuid4()),
            service_id=self.service_id,
            service_type=self.service_type,
            # Target the system controller directly to avoid broadcasting to all services.
            target_service_type=ServiceType.SYSTEM_CONTROLLER,
            state=self.state,
        )
        for _ in range(DEFAULT_MAX_REGISTRATION_ATTEMPTS):
            result = await self.send_command_and_wait_for_response(
                # NOTE: We keep the command id the same each time to ensure that the system controller
                #       can ignore duplicate registration requests.
                command_message,
                timeout=DEFAULT_REGISTRATION_INTERVAL,
            )
            if isinstance(result, CommandResponse):
                self.debug(
                    lambda: f"Service {self.service_id} registered with system controller"
                )
                break
        # NOTE(review): if every attempt yields something that is neither a
        # CommandResponse nor ErrorDetails, this falls through without raising —
        # confirm that silent best-effort behavior is intended.
        if isinstance(result, ErrorDetails):
            self.error(
                f"Failed to register service {self} ({self.service_id}): {result}"
            )
            raise self._service_error(
                f"Failed to register service {self} ({self.service_id}): {result}"
            )

    @on_state_change
    async def _on_state_change(
        self, old_state: LifecycleState, new_state: LifecycleState
    ) -> None:
        """Action to take when the service state is set.

        This method will also publish the status message to the status message_type if the
        communications are initialized.
        """
        # Skip publishing during shutdown or before comms have been brought up.
        if self.stop_requested:
            return
        if not self.comms.was_initialized:
            return
        await self.publish(
            StatusMessage(
                service_id=self.service_id,
                service_type=self.service_type,
                state=new_state,
            )
        )

    @on_command(CommandType.SHUTDOWN)
    async def _on_shutdown_command(self, message: CommandMessage) -> None:
        # Stop gracefully; escalate to a hard kill if stop() itself fails.
        self.debug(f"Received shutdown command: {message}, {self.service_id}")
        try:
            await self.stop()
        except Exception as e:
            self.warning(
                f"Failed to stop service {self} ({self.service_id}) after receiving shutdown command: {e}. Killing."
            )
            await self._kill()
        # Cancel the surrounding task so the handler's event loop unwinds.
        raise asyncio.CancelledError()

aiperf.common.base_service

BaseService

Bases: CommandHandlerMixin, ABC

Base class for all AIPerf services, providing common functionality for communication, state management, and lifecycle operations. This class inherits from the MessageBusClientMixin, which provides the message bus client functionality.

This class provides the foundation for implementing the various services of the AIPerf system. Some of the abstract methods are implemented here, while others are still required to be implemented by derived classes.

Source code in aiperf/common/base_service.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
class BaseService(CommandHandlerMixin, ABC):
    """Base class for all AIPerf services, providing common functionality for
    communication, state management, and lifecycle operations.
    This class inherits from the MessageBusClientMixin, which provides the
    message bus client functionality.

    This class provides the foundation for implementing the various services of the
    AIPerf system. Some of the abstract methods are implemented here, while others
    are still required to be implemented by derived classes.
    """

    # NOTE(review): the visible bases are CommandHandlerMixin/ABC; the docstring's
    # mention of MessageBusClientMixin presumably comes via the mixin chain — confirm.
    service_type: ClassVar[ServiceTypeT]
    """The type of service this class implements. This is set by the ServiceFactory.register decorator."""

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str | None = None,
        **kwargs,
    ) -> None:
        self.service_config = service_config
        self.user_config = user_config
        # Auto-derive a unique id (e.g. "<type>_a1b2c3d4") when none is supplied.
        self.service_id = service_id or f"{self.service_type}_{uuid.uuid4().hex[:8]}"
        # NOTE(review): both `service_id` and `id` are forwarded — presumably
        # different mixins in the MRO consume different parameter names; confirm.
        super().__init__(
            service_id=self.service_id,
            id=self.service_id,
            service_config=self.service_config,
            user_config=self.user_config,
            **kwargs,
        )
        self.debug(
            lambda: f"__init__ {self.service_type} service (id: {self.service_id})"
        )
        self._set_process_title()

    def _set_process_title(self) -> None:
        # Best-effort: give the OS process a recognizable name for ps/top.
        try:
            import setproctitle

            setproctitle.setproctitle(f"aiperf {self.service_id}")
        except Exception:
            # setproctitle is not available on all platforms, so we ignore the error
            self.debug("Failed to set process title, ignoring")

    def _service_error(self, message: str) -> ServiceError:
        # Helper to build a ServiceError tagged with this service's identity.
        return ServiceError(
            message=message,
            service_type=self.service_type,
            service_id=self.service_id,
        )

    @on_command(CommandType.SHUTDOWN)
    async def _on_shutdown_command(self, message: CommandMessage) -> None:
        self.debug(f"Received shutdown command from {message.service_id}")
        # Send an acknowledged response back to the sender, because we won't be able to send it after we stop.
        await self.publish(
            CommandAcknowledgedResponse.from_command_message(message, self.service_id)
        )

        try:
            await self.stop()
        except Exception as e:
            # If a graceful stop fails, escalate to a hard kill.
            self.exception(
                f"Failed to stop service {self} ({self.service_id}) after receiving shutdown command: {e}. Killing."
            )
            await self._kill()

    async def stop(self) -> None:
        """This overrides the base class stop method to handle the case where the service is already stopping.
        In this case, we need to kill the process to be safe."""
        if self.stop_requested:
            self.error(f"Attempted to stop {self} in state {self.state}. Killing.")
            await self._kill()
            return
        await super().stop()

    async def _kill(self) -> None:
        """Kill the lifecycle. This is used when the lifecycle is requested to stop, but is already in a stopping state.
        This is a last resort to ensure that the lifecycle is stopped.
        """
        await self._set_state(LifecycleState.FAILED)
        self.error(lambda: f"Killing {self}")
        self.stop_requested = True
        self.stopped_event.set()
        # TODO: This is a hack to ensure that the process is killed.
        #       We should find a better way to do this.
        os.kill(os.getpid(), signal.SIGKILL)
        raise asyncio.CancelledError(f"Killed {self}")

service_type class-attribute

The type of service this class implements. This is set by the ServiceFactory.register decorator.

stop() async

This overrides the base class stop method to handle the case where the service is already stopping. In this case, we need to kill the process to be safe.

Source code in aiperf/common/base_service.py
88
89
90
91
92
93
94
95
async def stop(self) -> None:
    """Stop the service, hard-killing it if a stop is already in flight.

    Overrides the base implementation to guard against re-entrant stop
    requests: a second stop while one is pending escalates to a kill.
    """
    if not self.stop_requested:
        await super().stop()
        return
    self.error(f"Attempted to stop {self} in state {self.state}. Killing.")
    await self._kill()

aiperf.common.bootstrap

bootstrap_and_run_service(service_class, service_config=None, user_config=None, service_id=None, log_queue=None, **kwargs)

Bootstrap the service and run it.

This function will load the service configuration, create an instance of the service, and run it.

Parameters:

Name Type Description Default
service_class type[ServiceProtocol]

The python class of the service to run. This should be a subclass of BaseService. This should be a type and not an instance.

required
service_config ServiceConfig | None

The service configuration to use. If not provided, the service configuration will be loaded from the environment variables.

None
user_config UserConfig | None

The user configuration to use. If not provided, the user configuration will be loaded from the environment variables.

None
log_queue Queue | None

Optional multiprocessing queue for child process logging. If provided, the child process logging will be set up.

None
kwargs

Additional keyword arguments to pass to the service constructor.

{}
Source code in aiperf/common/bootstrap.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def bootstrap_and_run_service(
    service_class: type[ServiceProtocol],
    service_config: ServiceConfig | None = None,
    user_config: UserConfig | None = None,
    service_id: str | None = None,
    log_queue: "multiprocessing.Queue | None" = None,
    **kwargs,
):
    """Bootstrap the service and run it.

    This function will load the service configuration,
    create an instance of the service, and run it.

    Args:
        service_class: The python class of the service to run. This should be a subclass of
            BaseService. This should be a type and not an instance.
        service_config: The service configuration to use. If not provided, the service
            configuration will be loaded from the environment variables.
        user_config: The user configuration to use. If not provided, the user configuration
            will be loaded from the environment variables.
        service_id: Optional explicit id for the service instance; passed through
            to the service constructor.
        log_queue: Optional multiprocessing queue for child process logging. If provided,
            the child process logging will be set up.
        kwargs: Additional keyword arguments to pass to the service constructor.
    """

    # Load the service configuration
    if service_config is None:
        from aiperf.common.config import load_service_config

        service_config = load_service_config()

    # Load the user configuration
    if user_config is None:
        from aiperf.common.config import load_user_config

        # TODO: Add support for loading user config from a file/environment variables
        user_config = load_user_config()

    async def _run_service():
        # Optionally start profiling before any service work happens.
        if service_config.enable_yappi:
            _start_yappi_profiling()

        service = service_class(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
            **kwargs,
        )

        from aiperf.common.logging import setup_child_process_logging

        # NOTE(review): the service is constructed before child-process logging
        # is configured, so constructor-time log records may bypass the queue —
        # confirm this ordering is intentional.
        setup_child_process_logging(
            log_queue, service.service_id, service_config, user_config
        )

        if user_config.input.random_seed is not None:
            random.seed(user_config.input.random_seed)
            # Try and set the numpy random seed
            # https://numpy.org/doc/stable/reference/random/index.html#random-quick-start
            with contextlib.suppress(ImportError):
                import numpy as np

                np.random.seed(user_config.input.random_seed)

        try:
            await service.initialize()
            await service.start()
            # Block until the service signals it has fully stopped.
            await service.stopped_event.wait()
        except Exception as e:
            service.exception(f"Unhandled exception in service: {e}")

        if service_config.enable_yappi:
            _stop_yappi_profiling(service.service_id, user_config)

    # CancelledError during shutdown is expected; don't let it escape as a crash.
    with contextlib.suppress(asyncio.CancelledError):
        if service_config.enable_uvloop:
            import uvloop

            uvloop.run(_run_service())
        else:
            asyncio.run(_run_service())

aiperf.common.comms.base_comms

BaseCommunication

Bases: AIPerfLifecycleMixin, ABC

Base class for specifying the base communication layer for AIPerf components.

Source code in aiperf/common/comms/base_comms.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
@implements_protocol(CommunicationProtocol)
class BaseCommunication(AIPerfLifecycleMixin, ABC):
    """Base class for specifying the base communication layer for AIPerf components."""

    @abstractmethod
    def get_address(self, address_type: CommAddressType) -> str:
        """Get the address for a given address type.

        Args:
            address_type: The type of address to get the address for, or the address itself.

        Returns:
            The address for the given address type, or the address itself if it is a string.
        """

    @abstractmethod
    def create_client(
        self,
        client_type: CommClientType,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
        max_pull_concurrency: int | None = None,
    ) -> CommunicationClientProtocol:
        """Create a communication client for a given client type and address.

        Args:
            client_type: The type of client to create.
            address: The type of address to use when looking up in the communication config, or the address itself.
            bind: Whether to bind or connect the socket.
            socket_ops: Additional socket options to set.
            max_pull_concurrency: The maximum number of concurrent pull requests to allow. (Only used for pull clients)
        """

    # The typed factory wrappers below all delegate to `create_client` and
    # narrow the return type via `cast` for callers.

    def create_pub_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> PubClientProtocol:
        """Create a PUB (publish) client; see `create_client` for argument semantics."""
        return cast(
            PubClientProtocol,
            self.create_client(CommClientType.PUB, address, bind, socket_ops),
        )

    def create_sub_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> SubClientProtocol:
        """Create a SUB (subscribe) client; see `create_client` for argument semantics."""
        return cast(
            SubClientProtocol,
            self.create_client(CommClientType.SUB, address, bind, socket_ops),
        )

    def create_push_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> PushClientProtocol:
        """Create a PUSH client; see `create_client` for argument semantics."""
        return cast(
            PushClientProtocol,
            self.create_client(CommClientType.PUSH, address, bind, socket_ops),
        )

    def create_pull_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
        max_pull_concurrency: int | None = None,
    ) -> PullClientProtocol:
        """Create a PULL client; the only wrapper that forwards `max_pull_concurrency`."""
        return cast(
            PullClientProtocol,
            self.create_client(
                CommClientType.PULL,
                address,
                bind,
                socket_ops,
                max_pull_concurrency=max_pull_concurrency,
            ),
        )

    def create_request_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> RequestClientProtocol:
        """Create a REQUEST client; see `create_client` for argument semantics."""
        return cast(
            RequestClientProtocol,
            self.create_client(CommClientType.REQUEST, address, bind, socket_ops),
        )

    def create_reply_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> ReplyClientProtocol:
        """Create a REPLY client; see `create_client` for argument semantics."""
        return cast(
            ReplyClientProtocol,
            self.create_client(CommClientType.REPLY, address, bind, socket_ops),
        )

create_client(client_type, address, bind=False, socket_ops=None, max_pull_concurrency=None) abstractmethod

Create a communication client for a given client type and address.

Parameters:

Name Type Description Default
client_type CommClientType

The type of client to create.

required
address CommAddressType

The type of address to use when looking up in the communication config, or the address itself.

required
bind bool

Whether to bind or connect the socket.

False
socket_ops dict | None

Additional socket options to set.

None
max_pull_concurrency int | None

The maximum number of concurrent pull requests to allow. (Only used for pull clients)

None
Source code in aiperf/common/comms/base_comms.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Implemented by concrete communication backends registered with the client factory.
@abstractmethod
def create_client(
    self,
    client_type: CommClientType,
    address: CommAddressType,
    bind: bool = False,
    socket_ops: dict | None = None,
    max_pull_concurrency: int | None = None,
) -> CommunicationClientProtocol:
    """Create a communication client for a given client type and address.

    Args:
        client_type: The type of client to create.
        address: The type of address to use when looking up in the communication config, or the address itself.
        bind: Whether to bind or connect the socket.
        socket_ops: Additional socket options to set.
        max_pull_concurrency: The maximum number of concurrent pull requests to allow. (Only used for pull clients)

    Returns:
        A client object implementing the protocol that matches `client_type`.
    """

get_address(address_type) abstractmethod

Get the address for a given address type.

Parameters:

Name Type Description Default
address_type CommAddressType

The type of address to get the address for, or the address itself.

required

Returns:

Type Description
str

The address for the given address type, or the address itself if it is a string.

Source code in aiperf/common/comms/base_comms.py
26
27
28
29
30
31
32
33
34
35
# Implemented by concrete communication backends; maps logical address types
# to transport-specific endpoint strings.
@abstractmethod
def get_address(self, address_type: CommAddressType) -> str:
    """Get the address for a given address type.

    Args:
        address_type: The type of address to get the address for, or the address itself.

    Returns:
        The address for the given address type, or the address itself if it is a string.
    """

aiperf.common.comms.zmq.dealer_request_client

ZMQDealerRequestClient

Bases: BaseZMQClient, TaskManagerMixin

ZMQ DEALER socket client for asynchronous request-response communication.

The DEALER socket connects to ROUTER sockets and can send requests asynchronously, receiving responses through callbacks or awaitable futures.

ASCII Diagram: ┌──────────────┐ ┌──────────────┐ │ DEALER │───── Request ─────>│ ROUTER │ │ (Client) │ │ (Service) │ │ │<─── Response ──────│ │ └──────────────┘ └──────────────┘

Usage Pattern: - DEALER Clients send requests to ROUTER Services - Responses are routed back to the originating DEALER

DEALER/ROUTER is a Many-to-One communication pattern. If you need Many-to-Many, use a ZMQ Proxy as well. see :class:ZMQDealerRouterProxy for more details.

Source code in aiperf/common/comms/zmq/dealer_request_client.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
@implements_protocol(RequestClientProtocol)
@CommunicationClientFactory.register(CommClientType.REQUEST)
class ZMQDealerRequestClient(BaseZMQClient, TaskManagerMixin):
    """
    ZMQ DEALER socket client for asynchronous request-response communication.

    The DEALER socket connects to ROUTER sockets and can send requests asynchronously,
    receiving responses through callbacks or awaitable futures.

    ASCII Diagram:
    ┌──────────────┐                    ┌──────────────┐
    │    DEALER    │───── Request ─────>│    ROUTER    │
    │   (Client)   │                    │  (Service)   │
    │              │<─── Response ──────│              │
    └──────────────┘                    └──────────────┘

    Usage Pattern:
    - DEALER Clients send requests to ROUTER Services
    - Responses are routed back to the originating DEALER

    DEALER/ROUTER is a Many-to-One communication pattern. If you need Many-to-Many,
    use a ZMQ Proxy as well. See :class:`ZMQDealerRouterProxy` for more details.
    """

    def __init__(
        self,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the ZMQ Dealer (Req) client class.

        Args:
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
        """
        super().__init__(zmq.SocketType.DEALER, address, bind, socket_ops, **kwargs)

        # Maps request_id -> response callback; entries are consumed (popped)
        # by the receiver task when the matching response arrives, or removed
        # by `request` when a request times out.
        self.request_callbacks: dict[
            str, Callable[[Message], Coroutine[Any, Any, None]]
        ] = {}

    @background_task(immediate=True, interval=None)
    async def _request_async_task(self) -> None:
        """Background task that receives responses from the socket and
        dispatches each to the callback registered for its request_id."""
        while not self.stop_requested:
            try:
                message = await self.socket.recv_string()
                self.trace(lambda msg=message: f"Received response: {msg}")
                response_message = Message.from_json(message)

                # Call the callback if it exists
                if response_message.request_id in self.request_callbacks:
                    callback = self.request_callbacks.pop(response_message.request_id)
                    self.execute_async(callback(response_message))

            # NOTE: CancelledError is handled before the generic Exception
            # clause so cancellation always propagates (and stays correct on
            # interpreters where CancelledError subclasses Exception).
            except asyncio.CancelledError:
                self.debug("Dealer request client receiver task cancelled")
                raise  # re-raise the cancelled error
            except zmq.Again:
                self.debug("No data on dealer socket received, yielding to event loop")
                await yield_to_event_loop()
            except Exception as e:
                self.exception(f"Exception receiving responses: {e}")
                await yield_to_event_loop()

    @on_stop
    async def _stop_remaining_tasks(self) -> None:
        """Wait for all tasks to complete."""
        await self.cancel_all_tasks()

    async def request_async(
        self,
        message: Message,
        callback: Callable[[Message], Coroutine[Any, Any, None]],
    ) -> None:
        """Send a request and be notified when the response is received.

        Args:
            message: The request message to send. A request_id is generated
                and assigned if the message does not already have one.
            callback: Coroutine function invoked with the matching response.

        Raises:
            TypeError: If `message` is not a Message instance.
            CommunicationError: If sending the request over the socket fails.
        """
        await self._check_initialized()

        if not isinstance(message, Message):
            raise TypeError(
                f"message must be an instance of Message, got {type(message).__name__}"
            )

        # Generate request ID if not provided so that responses can be matched
        if not message.request_id:
            message.request_id = str(uuid.uuid4())

        self.request_callbacks[message.request_id] = callback

        request_json = message.model_dump_json()
        self.trace(lambda msg=request_json: f"Sending request: {msg}")

        try:
            await self.socket.send_string(request_json)

        except Exception as e:
            raise CommunicationError(
                f"Exception sending request: {e.__class__.__qualname__} {e}",
            ) from e

    async def request(
        self,
        message: Message,
        timeout: float = DEFAULT_COMMS_REQUEST_TIMEOUT,
    ) -> Message:
        """Send a request and wait for a response up to timeout seconds.

        Args:
            message (Message): The request message to send.
            timeout (float): Maximum time to wait for a response in seconds.

        Returns:
            Message: The response message received.

        Raises:
            CommunicationError: if the request fails, or
            asyncio.TimeoutError: if the response is not received in time.
        """
        future: asyncio.Future[Message] = asyncio.get_running_loop().create_future()

        async def callback(response_message: Message) -> None:
            # BUG FIX: guard against a late response. After a timeout below,
            # wait_for cancels the future; calling set_result() on it would
            # raise InvalidStateError inside the receiver task.
            if not future.done():
                future.set_result(response_message)

        await self.request_async(message, callback)
        try:
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            # BUG FIX: drop the now-stale callback so it does not leak in
            # request_callbacks forever if the response never arrives.
            # (request_async assigned the request_id before sending.)
            if message.request_id:
                self.request_callbacks.pop(message.request_id, None)
            raise

__init__(address, bind, socket_ops=None, **kwargs)

Initialize the ZMQ Dealer (Req) client class.

Parameters:

Name Type Description Default
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/dealer_request_client.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def __init__(
    self,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
    **kwargs,
) -> None:
    """
    Initialize the ZMQ Dealer (Req) client class.

    Opens a DEALER socket via the base class and prepares the registry
    used to match responses to in-flight requests.

    Args:
        address (str): The address to bind or connect to.
        bind (bool): Whether to bind or connect the socket.
        socket_ops (dict, optional): Additional socket options to set.
    """
    super().__init__(zmq.SocketType.DEALER, address, bind, socket_ops, **kwargs)

    # Maps request_id -> response callback; entries are consumed (popped)
    # when the matching response arrives.
    self.request_callbacks: dict[
        str, Callable[[Message], Coroutine[Any, Any, None]]
    ] = {}

request(message, timeout=DEFAULT_COMMS_REQUEST_TIMEOUT) async

Send a request and wait for a response up to timeout seconds.

Parameters:

Name Type Description Default
message Message

The request message to send.

required
timeout float

Maximum time to wait for a response in seconds.

DEFAULT_COMMS_REQUEST_TIMEOUT

Returns:

Name Type Description
Message Message

The response message received.

Raises:

Type Description
CommunicationError

if the request fails, or

TimeoutError

if the response is not received in time.

Source code in aiperf/common/comms/zmq/dealer_request_client.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
async def request(
    self,
    message: Message,
    timeout: float = DEFAULT_COMMS_REQUEST_TIMEOUT,
) -> Message:
    """Send a request and wait for a response up to timeout seconds.

    Args:
        message (Message): The request message to send.
        timeout (float): Maximum time to wait for a response in seconds.

    Returns:
        Message: The response message received.

    Raises:
        CommunicationError: if the request fails, or
        asyncio.TimeoutError: if the response is not received in time.
    """
    future: asyncio.Future[Message] = asyncio.get_running_loop().create_future()

    async def callback(response_message: Message) -> None:
        # BUG FIX: guard against a late response. After a timeout below,
        # wait_for cancels the future; calling set_result() on it would
        # raise InvalidStateError inside the receiver task.
        if not future.done():
            future.set_result(response_message)

    await self.request_async(message, callback)
    try:
        return await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        # BUG FIX: drop the now-stale callback so it does not leak in
        # request_callbacks forever if the response never arrives.
        # (request_async assigned the request_id before sending.)
        if message.request_id:
            self.request_callbacks.pop(message.request_id, None)
        raise

request_async(message, callback) async

Send a request and be notified when the response is received.

Source code in aiperf/common/comms/zmq/dealer_request_client.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
async def request_async(
    self,
    message: Message,
    callback: Callable[[Message], Coroutine[Any, Any, None]],
) -> None:
    """Send a request and be notified when the response is received.

    Args:
        message: The request message to send. A request_id is generated
            and assigned if the message does not already have one.
        callback: Coroutine function invoked with the response message
            carrying the matching request_id.

    Raises:
        TypeError: If `message` is not a Message instance.
        CommunicationError: If sending the request over the socket fails.
    """
    await self._check_initialized()

    if not isinstance(message, Message):
        raise TypeError(
            f"message must be an instance of Message, got {type(message).__name__}"
        )

    # Generate request ID if not provided so that responses can be matched
    if not message.request_id:
        message.request_id = str(uuid.uuid4())

    self.request_callbacks[message.request_id] = callback

    request_json = message.model_dump_json()
    self.trace(lambda msg=request_json: f"Sending request: {msg}")

    try:
        await self.socket.send_string(request_json)

    except Exception as e:
        raise CommunicationError(
            f"Exception sending request: {e.__class__.__qualname__} {e}",
        ) from e

aiperf.common.comms.zmq.pub_client

ZMQPubClient

Bases: BaseZMQClient

The PUB socket broadcasts messages to all connected SUB sockets that have subscribed to the message topic/type.

ASCII Diagram: ┌──────────────┐ ┌──────────────┐ │ PUB │───>│ │ │ (Publisher) │ │ │ └──────────────┘ │ SUB │ ┌──────────────┐ │ (Subscriber) │ │ PUB │───>│ │ │ (Publisher) │ │ │ └──────────────┘ └──────────────┘ OR ┌──────────────┐ ┌──────────────┐ │ │───>│ SUB │ │ │ │ (Subscriber) │ │ PUB │ └──────────────┘ │ (Publisher) │ ┌──────────────┐ │ │───>│ SUB │ │ │ │ (Subscriber) │ └──────────────┘ └──────────────┘

Usage Pattern: - Single PUB socket broadcasts messages to all subscribers (One-to-Many) OR - Multiple PUB sockets broadcast messages to a single SUB socket (Many-to-One)

  • SUB sockets filter messages by topic/message_type
  • Fire-and-forget messaging (no acknowledgments)

PUB/SUB is a One-to-Many communication pattern. If you need Many-to-Many, use a ZMQ Proxy as well. See :class:`ZMQXPubXSubProxy` for more details.

Source code in aiperf/common/comms/zmq/pub_client.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
@implements_protocol(PubClientProtocol)
@CommunicationClientFactory.register(CommClientType.PUB)
class ZMQPubClient(BaseZMQClient):
    """
    The PUB socket broadcasts messages to all connected SUB sockets that have
    subscribed to the message topic/type.

    ASCII Diagram:
    ┌──────────────┐    ┌──────────────┐
    │     PUB      │───>│              │
    │ (Publisher)  │    │              │
    └──────────────┘    │     SUB      │
    ┌──────────────┐    │ (Subscriber) │
    │     PUB      │───>│              │
    │ (Publisher)  │    │              │
    └──────────────┘    └──────────────┘
    OR
    ┌──────────────┐    ┌──────────────┐
    │              │───>│     SUB      │
    │              │    │ (Subscriber) │
    │     PUB      │    └──────────────┘
    │ (Publisher)  │    ┌──────────────┐
    │              │───>│     SUB      │
    │              │    │ (Subscriber) │
    └──────────────┘    └──────────────┘

    Usage Pattern:
    - Single PUB socket broadcasts messages to all subscribers (One-to-Many)
    OR
    - Multiple PUB sockets broadcast messages to a single SUB socket (Many-to-One)

    - SUB sockets filter messages by topic/message_type
    - Fire-and-forget messaging (no acknowledgments)

    PUB/SUB is a One-to-Many communication pattern. If you need Many-to-Many,
    use a ZMQ Proxy as well. see :class:`ZMQXPubXSubProxy` for more details.
    """

    def __init__(
        self,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the ZMQ Publisher client class.

        Args:
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
        """
        super().__init__(zmq.SocketType.PUB, address, bind, socket_ops, **kwargs)

    async def publish(self, message: Message) -> None:
        """Publish a message. The topic will be set automatically based on the message type.

        Args:
            message: Message to publish (must be a Message object)

        Raises:
            CommunicationError: If the message could not be published.
        """
        await self._check_initialized()

        try:
            topic = self._determine_topic(message)
            message_json = message.model_dump_json()
            self.trace(lambda: f"Publishing message {topic=} {message_json=}")
            # Frame 0 carries the topic for SUB-side prefix filtering;
            # frame 1 carries the serialized message payload.
            frames = [topic.encode(), message_json.encode()]
            await self.socket.send_multipart(frames)

        except (asyncio.CancelledError, zmq.ContextTerminated):
            self.debug(
                lambda: f"Pub client {self.client_id} cancelled or context terminated"
            )
            return

        except Exception as e:
            raise CommunicationError(
                f"Failed to publish message {message.message_type}: {e}",
            ) from e

    def _determine_topic(self, message: Message) -> str:
        """Resolve the topic string to publish `message` under."""
        # Targeted messages (such as commands) can be scoped to a specific
        # service by id, or failing that by type; the id always wins.
        # NOTE: ZMQ subscriptions are prefix-based wildcards, so the unique
        # portion of the topic has to come first.
        target = None
        if isinstance(message, TargetedServiceMessage):
            target = message.target_service_id or message.target_service_type
        if not target:
            return f"{message.message_type}{TOPIC_END}"
        return f"{message.message_type}{TOPIC_DELIMITER}{target}{TOPIC_END}"

__init__(address, bind, socket_ops=None, **kwargs)

Initialize the ZMQ Publisher client class.

Parameters:

Name Type Description Default
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/pub_client.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def __init__(
    self,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
    **kwargs,
) -> None:
    """
    Initialize the ZMQ Publisher client class.

    Opens a PUB socket via the base class; all behavior beyond socket
    creation is inherited.

    Args:
        address (str): The address to bind or connect to.
        bind (bool): Whether to bind or connect the socket.
        socket_ops (dict, optional): Additional socket options to set.
    """
    super().__init__(zmq.SocketType.PUB, address, bind, socket_ops, **kwargs)

publish(message) async

Publish a message. The topic will be set automatically based on the message type.

Parameters:

Name Type Description Default
message Message

Message to publish (must be a Message object)

required
Source code in aiperf/common/comms/zmq/pub_client.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
async def publish(self, message: Message) -> None:
    """Publish a message. The topic will be set automatically based on the message type.

    Cancellation and context termination are treated as a clean no-op exit
    rather than an error.

    Args:
        message: Message to publish (must be a Message object)

    Raises:
        CommunicationError: If the message could not be published.
    """
    await self._check_initialized()

    try:
        topic = self._determine_topic(message)
        message_json = message.model_dump_json()
        # Publish message
        self.trace(lambda: f"Publishing message {topic=} {message_json=}")
        await self.socket.send_multipart([topic.encode(), message_json.encode()])

    except (asyncio.CancelledError, zmq.ContextTerminated):
        self.debug(
            lambda: f"Pub client {self.client_id} cancelled or context terminated"
        )
        return

    except Exception as e:
        raise CommunicationError(
            f"Failed to publish message {message.message_type}: {e}",
        ) from e

aiperf.common.comms.zmq.pull_client

ZMQPullClient

Bases: BaseZMQClient

ZMQ PULL socket client for receiving work from PUSH sockets.

The PULL socket receives messages from PUSH sockets in a pipeline pattern, distributing work fairly among multiple PULL workers.

ASCII Diagram: ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ PUSH │ │ PULL │ │ PULL │ │ (Producer) │ │ (Worker 1) │ │ (Worker 2) │ │ │ └─────────────┘ └─────────────┘ │ Tasks: │ ▲ ▲ │ - Task A │─────────────┘ │ │ - Task B │───────────────────────────────────┘ │ - Task C │─────────────┐ │ - Task D │ ▼ └─────────────┘ ┌─────────────┐ │ PULL │ │ (Worker N) │ └─────────────┘

Usage Pattern: - PULL receives work from multiple PUSH producers - Work is fairly distributed among PULL workers - Pipeline pattern for distributed processing - Each message is delivered to exactly one PULL socket

PULL/PUSH is a One-to-Many communication pattern. If you need Many-to-Many, use a ZMQ Proxy as well. See :class:`ZMQPushPullProxy` for more details.

Source code in aiperf/common/comms/zmq/pull_client.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
@implements_protocol(PullClientProtocol)
@CommunicationClientFactory.register(CommClientType.PULL)
class ZMQPullClient(BaseZMQClient):
    """
    ZMQ PULL socket client for receiving work from PUSH sockets.

    The PULL socket receives messages from PUSH sockets in a pipeline pattern,
    distributing work fairly among multiple PULL workers.

    ASCII Diagram:
    ┌─────────────┐      ┌─────────────┐      ┌─────────────┐
    │    PUSH     │      │    PULL     │      │    PULL     │
    │ (Producer)  │      │ (Worker 1)  │      │ (Worker 2)  │
    │             │      └─────────────┘      └─────────────┘
    │   Tasks:    │             ▲                     ▲
    │   - Task A  │─────────────┘                     │
    │   - Task B  │───────────────────────────────────┘
    │   - Task C  │─────────────┐
    │   - Task D  │             ▼
    └─────────────┘      ┌─────────────┐
                         │    PULL     │
                         │ (Worker N)  │
                         └─────────────┘

    Usage Pattern:
    - PULL receives work from multiple PUSH producers
    - Work is fairly distributed among PULL workers
    - Pipeline pattern for distributed processing
    - Each message is delivered to exactly one PULL socket

    PULL/PUSH is a One-to-Many communication pattern. If you need Many-to-Many,
    use a ZMQ Proxy as well. See :class:`ZMQPushPullProxy` for more details.
    """

    def __init__(
        self,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
        max_pull_concurrency: int | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the ZMQ Puller class.

        Args:
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
            max_pull_concurrency (int, optional): The maximum number of concurrent requests to allow.
                When omitted, falls back to the AIPERF_WORKER_CONCURRENT_REQUESTS
                environment variable (default 500).
        """
        super().__init__(zmq.SocketType.PULL, address, bind, socket_ops, **kwargs)
        # One callback per message type; enforced by register_pull_callback.
        self._pull_callbacks: dict[
            MessageTypeT, Callable[[Message], Coroutine[Any, Any, None]]
        ] = {}

        # Bounds how many pulled messages may be processed concurrently;
        # acquired before each socket receive (see _pull_receiver).
        if max_pull_concurrency is not None:
            self.semaphore = asyncio.Semaphore(value=max_pull_concurrency)
        else:
            self.semaphore = asyncio.Semaphore(
                value=int(os.getenv("AIPERF_WORKER_CONCURRENT_REQUESTS", 500))
            )

    @background_task(immediate=True, interval=None)
    async def _pull_receiver(self) -> None:
        """Background task for receiving data from the pull socket.

        This method is a coroutine that will run indefinitely until the client is
        shutdown. It will wait for messages from the socket and handle them.
        """
        while not self.stop_requested:
            try:
                # acquire the semaphore to limit the number of concurrent requests
                # NOTE: This MUST be done BEFORE calling recv_string() to allow the zmq push/pull
                # logic to properly load balance the requests.
                await self.semaphore.acquire()

                message_json = await self.socket.recv_string()
                self.trace(
                    lambda msg=message_json: f"Received message from pull socket: {msg}"
                )
                self.execute_async(self._process_message(message_json))

            # BUG FIX: this clause must come BEFORE `except Exception` --
            # zmq.ContextTerminated subclasses Exception, so the generic
            # handler used to swallow it and the loop never exited on
            # context termination.
            except (asyncio.CancelledError, zmq.ContextTerminated):
                self.debug("Pull client receiver task cancelled")
                self.semaphore.release()  # release the semaphore as it was not used
                break
            except zmq.Again:
                self.debug("Pull client receiver task timed out")
                self.semaphore.release()  # release the semaphore as it was not used
                await yield_to_event_loop()
            except Exception as e:
                self.exception(f"Exception receiving data from pull socket: {e}")
                self.semaphore.release()  # release the semaphore as it was not used
                await yield_to_event_loop()

    @on_stop
    async def _stop(self) -> None:
        """Wait for all tasks to complete."""
        await self.cancel_all_tasks()

    async def _process_message(self, message_json: str) -> None:
        """Process a message from the pull socket.

        This method is called by the background task when a message is received from
        the pull socket. It will deserialize the message and call the appropriate
        callback function.
        """
        try:
            message = Message.from_json(message_json)

            # Call callbacks with Message object
            if message.message_type in self._pull_callbacks:
                await self._pull_callbacks[message.message_type](message)
            else:
                self.warning(
                    lambda message_type=message.message_type: f"Pull message received for message type {message_type} without callback"
                )
        finally:
            # always release the semaphore to allow receiving more messages
            self.semaphore.release()

    def register_pull_callback(
        self,
        message_type: MessageTypeT,
        callback: Callable[[Message], Coroutine[Any, Any, None]],
    ) -> None:
        """Register a ZMQ Pull data callback for a given message type.

        Note that only one callback can be registered for a given message type.

        Args:
            message_type: The message type to register the callback for.
            callback: The function to call when data is received.
        Raises:
            ValueError: If a callback is already registered for the message type.
        """
        # Register callback
        if message_type not in self._pull_callbacks:
            self._pull_callbacks[message_type] = callback
        else:
            raise ValueError(
                f"Callback already registered for message type {message_type}"
            )

__init__(address, bind, socket_ops=None, max_pull_concurrency=None, **kwargs)

Initialize the ZMQ Puller class.

Parameters:

Name Type Description Default
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
max_pull_concurrency int

The maximum number of concurrent requests to allow.

None
Source code in aiperf/common/comms/zmq/pull_client.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def __init__(
    self,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
    max_pull_concurrency: int | None = None,
    **kwargs,
) -> None:
    """
    Initialize the ZMQ Puller class.

    Args:
        address (str): The address to bind or connect to.
        bind (bool): Whether to bind or connect the socket.
        socket_ops (dict, optional): Additional socket options to set.
        max_pull_concurrency (int, optional): The maximum number of concurrent requests to allow.
            When omitted, falls back to the AIPERF_WORKER_CONCURRENT_REQUESTS
            environment variable (default 500).
    """
    super().__init__(zmq.SocketType.PULL, address, bind, socket_ops, **kwargs)
    # One callback per message type; enforced by register_pull_callback.
    self._pull_callbacks: dict[
        MessageTypeT, Callable[[Message], Coroutine[Any, Any, None]]
    ] = {}

    # Bounds how many pulled messages may be processed concurrently.
    if max_pull_concurrency is not None:
        self.semaphore = asyncio.Semaphore(value=max_pull_concurrency)
    else:
        self.semaphore = asyncio.Semaphore(
            value=int(os.getenv("AIPERF_WORKER_CONCURRENT_REQUESTS", 500))
        )

register_pull_callback(message_type, callback)

Register a ZMQ Pull data callback for a given message type.

Note that only one callback can be registered for a given message type.

Parameters:

Name Type Description Default
message_type MessageTypeT

The message type to register the callback for.

required
callback Callable[[Message], Coroutine[Any, Any, None]]

The function to call when data is received.

required

Raises: ValueError: If a callback is already registered for the given message type

Source code in aiperf/common/comms/zmq/pull_client.py
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def register_pull_callback(
    self,
    message_type: MessageTypeT,
    callback: Callable[[Message], Coroutine[Any, Any, None]],
) -> None:
    """Register a ZMQ Pull data callback for a given message type.

    Note that only one callback can be registered for a given message type.

    Args:
        message_type: The message type to register the callback for.
        callback: The function to call when data is received.
    Raises:
        ValueError: If a callback is already registered for the message type.
    """
    # Register callback
    if message_type not in self._pull_callbacks:
        self._pull_callbacks[message_type] = callback
    else:
        raise ValueError(
            f"Callback already registered for message type {message_type}"
        )

aiperf.common.comms.zmq.push_client

MAX_PUSH_RETRIES = 2 module-attribute

Maximum number of retries for pushing a message.

RETRY_DELAY_INTERVAL_SEC = 0.1 module-attribute

The interval to wait before retrying to push a message.

ZMQPushClient

Bases: BaseZMQClient

ZMQ PUSH socket client for sending work to PULL sockets.

The PUSH socket sends messages to PULL sockets in a pipeline pattern, distributing work fairly among available PULL workers.

ASCII Diagram: ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ PUSH │ │ PULL │ │ PULL │ │ (Producer) │ │ (Worker 1) │ │ (Worker 2) │ │ │ └─────────────┘ └─────────────┘ │ Tasks: │ ▲ ▲ │ - Task A │─────────────┘ │ │ - Task B │───────────────────────────────────┘ │ - Task C │─────────────┐ │ - Task D │ ▼ └─────────────┘ ┌─────────────┐ │ PULL │ │ (Worker 3) │ └─────────────┘

Usage Pattern: - Round-robin distribution of work tasks (One-to-Many) - Each message delivered to exactly one worker - Pipeline pattern for distributed processing - Automatic load balancing across available workers

PUSH/PULL is a One-to-Many communication pattern. If you need Many-to-Many, use a ZMQ Proxy as well. See :class:`ZMQPushPullProxy` for more details.

Source code in aiperf/common/comms/zmq/push_client.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
@implements_protocol(PushClientProtocol)
@CommunicationClientFactory.register(CommClientType.PUSH)
class ZMQPushClient(BaseZMQClient):
    """
    ZMQ PUSH socket client for sending work to PULL sockets.

    The PUSH socket sends messages to PULL sockets in a pipeline pattern,
    distributing work fairly among available PULL workers.

    ASCII Diagram:
    ┌─────────────┐      ┌─────────────┐      ┌─────────────┐
    │    PUSH     │      │    PULL     │      │    PULL     │
    │ (Producer)  │      │ (Worker 1)  │      │ (Worker 2)  │
    │             │      └─────────────┘      └─────────────┘
    │   Tasks:    │             ▲                     ▲
    │   - Task A  │─────────────┘                     │
    │   - Task B  │───────────────────────────────────┘
    │   - Task C  │─────────────┐
    │   - Task D  │             ▼
    └─────────────┘      ┌─────────────┐
                         │    PULL     │
                         │ (Worker 3)  │
                         └─────────────┘

    Usage Pattern:
    - Round-robin distribution of work tasks (One-to-Many)
    - Each message delivered to exactly one worker
    - Pipeline pattern for distributed processing
    - Automatic load balancing across available workers

    PUSH/PULL is a One-to-Many communication pattern. If you need Many-to-Many,
    use a ZMQ Proxy as well. see :class:`ZMQPushPullProxy` for more details.
    """

    def __init__(
        self,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the ZMQ Push client class.

        Args:
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
        """
        super().__init__(zmq.SocketType.PUSH, address, bind, socket_ops, **kwargs)

    async def _push_message(
        self,
        message: Message,
        retry_count: int = 0,
        max_retries: int = MAX_PUSH_RETRIES,
    ) -> None:
        """Push a message to the socket. Will retry up to max_retries times.

        Args:
            message: Message to be sent must be a Message object
            retry_count: Current retry count
            max_retries: Maximum number of times to retry pushing the message
        """
        # Iterative retry loop: attempt the send; on a send timeout, wait
        # RETRY_DELAY_INTERVAL_SEC and try again until max_retries is hit.
        while True:
            try:
                data_json = message.model_dump_json()
                await self.socket.send_string(data_json)
                self.trace(lambda msg=data_json: f"Pushed json data: {msg}")
                return
            except (asyncio.CancelledError, zmq.ContextTerminated):
                # Shutting down -- drop the message silently.
                self.debug("Push client cancelled or context terminated")
                return
            except zmq.Again as e:
                self.debug("Push client timed out")
                if retry_count >= max_retries:
                    raise CommunicationError(
                        f"Failed to push data after {retry_count} retries: {e}",
                    ) from e
                await asyncio.sleep(RETRY_DELAY_INTERVAL_SEC)
                retry_count += 1
            except Exception as e:
                raise CommunicationError(f"Failed to push data: {e}") from e

    async def push(self, message: Message) -> None:
        """Push data to a target. The message will be routed automatically
        based on the message type.

        Args:
            message: Message to be sent must be a Message object
        """
        await self._check_initialized()

        await self._push_message(message)

__init__(address, bind, socket_ops=None, **kwargs)

Initialize the ZMQ Push client class.

Parameters:

Name Type Description Default
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/push_client.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def __init__(
    self,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
    **kwargs,
) -> None:
    """Create a ZMQ PUSH client.

    Args:
        address (str): Endpoint to bind or connect to.
        bind (bool): True to bind the socket, False to connect it.
        socket_ops (dict, optional): Extra socket options to apply.
    """
    # All socket setup is delegated to the base class; only the
    # socket type (PUSH) is fixed here.
    super().__init__(zmq.SocketType.PUSH, address, bind, socket_ops, **kwargs)

push(message) async

Push data to a target. The message will be routed automatically based on the message type.

Parameters:

Name Type Description Default
message Message

Message to be sent must be a Message object

required
Source code in aiperf/common/comms/zmq/push_client.py
106
107
108
109
110
111
112
113
114
115
async def push(self, message: Message) -> None:
    """Send *message* over the PUSH socket.

    The message is routed automatically based on its message type.

    Args:
        message: The Message object to send.

    Raises:
        NotInitializedError: If the underlying socket is not initialized.
    """
    # Verify the socket is usable before attempting the push.
    await self._check_initialized()
    await self._push_message(message)

aiperf.common.comms.zmq.router_reply_client

ZMQRouterReplyClient

Bases: BaseZMQClient

ZMQ ROUTER socket client for handling requests from DEALER clients.

The ROUTER socket receives requests from DEALER clients and sends responses back to the originating DEALER client using routing envelopes.

ASCII Diagram: ┌──────────────┐ ┌──────────────┐ │ DEALER │───── Request ─────>│ │ │ (Client) │<──── Response ─────│ │ └──────────────┘ │ │ ┌──────────────┐ │ ROUTER │ │ DEALER │───── Request ─────>│ (Service) │ │ (Client) │<──── Response ─────│ │ └──────────────┘ │ │ ┌──────────────┐ │ │ │ DEALER │───── Request ─────>│ │ │ (Client) │<──── Response ─────│ │ └──────────────┘ └──────────────┘

Usage Pattern: - ROUTER handles requests from multiple DEALER clients - Maintains routing envelopes to send responses back - Many-to-one request handling pattern - Supports concurrent request processing

ROUTER/DEALER is a Many-to-One communication pattern. If you need Many-to-Many, add a ZMQ proxy in between; see :class:ZMQDealerRouterProxy for more details.

Source code in aiperf/common/comms/zmq/router_reply_client.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
@implements_protocol(ReplyClientProtocol)
@CommunicationClientFactory.register(CommClientType.REPLY)
class ZMQRouterReplyClient(BaseZMQClient):
    """
    ZMQ ROUTER socket client for handling requests from DEALER clients.

    The ROUTER socket receives requests from DEALER clients and sends responses
    back to the originating DEALER client using routing envelopes.

    ASCII Diagram:
    ┌──────────────┐                    ┌──────────────┐
    │    DEALER    │───── Request ─────>│              │
    │   (Client)   │<──── Response ─────│              │
    └──────────────┘                    │              │
    ┌──────────────┐                    │    ROUTER    │
    │    DEALER    │───── Request ─────>│  (Service)   │
    │   (Client)   │<──── Response ─────│              │
    └──────────────┘                    │              │
    ┌──────────────┐                    │              │
    │    DEALER    │───── Request ─────>│              │
    │   (Client)   │<──── Response ─────│              │
    └──────────────┘                    └──────────────┘

    Usage Pattern:
    - ROUTER handles requests from multiple DEALER clients
    - Maintains routing envelopes to send responses back
    - Many-to-one request handling pattern
    - Supports concurrent request processing

    ROUTER/DEALER is a Many-to-One communication pattern. If you need Many-to-Many,
    use a ZMQ Proxy as well. see :class:`ZMQDealerRouterProxy` for more details.
    """

    def __init__(
        self,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the ZMQ Router (Rep) client class.

        Args:
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
        """
        super().__init__(zmq.SocketType.ROUTER, address, bind, socket_ops, **kwargs)

        # message_type -> (service_id, async handler). At most one handler
        # per message type (enforced in register_request_handler).
        self._request_handlers: dict[
            MessageTypeT,
            tuple[str, Callable[[Message], Coroutine[Any, Any, Message | None]]],
        ] = {}
        # request_id -> future resolved by _handle_request and awaited by
        # _wait_for_response.
        self._response_futures: dict[str, asyncio.Future[Message | None]] = {}

    @on_stop
    async def _clear_request_handlers(self) -> None:
        # Drop handler references on stop so they cannot be invoked after
        # shutdown (and cannot keep their owners alive).
        self._request_handlers.clear()

    def register_request_handler(
        self,
        service_id: str,
        message_type: MessageTypeT,
        handler: Callable[[Message], Coroutine[Any, Any, Message | None]],
    ) -> None:
        """Register a request handler. Anytime a request is received that matches the
        message type, the handler will be called. The handler should return a response
        message. If the handler returns None, the request will be ignored.

        Note that there is a limit of 1 to 1 mapping between message type and handler.

        Args:
            service_id: The service ID to register the handler for
            message_type: The message type to register the handler for
            handler: The handler to register

        Raises:
            ValueError: If a handler is already registered for this message type.
        """
        if message_type in self._request_handlers:
            raise ValueError(
                f"Handler already registered for message type {message_type}"
            )

        # Bind the values as lambda defaults so the lazy log message captures
        # the current values, not later rebindings.
        self.debug(
            lambda service_id=service_id,
            msg_type=message_type: f"Registering request handler for {service_id} with message type {msg_type}"
        )
        self._request_handlers[message_type] = (service_id, handler)

    async def _handle_request(self, request_id: str, request: Message) -> None:
        """Dispatch an already-parsed request to its registered handler.

        This method will:
        - Look up the handler for the request's message type
        - Await the handler to produce a response
        - Resolve the response future for this request_id

        A missing handler (KeyError) or a handler failure is converted into
        an ErrorMessage response rather than propagating.
        """
        message_type = request.message_type

        try:
            _, handler = self._request_handlers[message_type]
            response = await handler(request)

        except Exception as e:
            self.exception(f"Exception calling handler for {message_type}: {e}")
            response = ErrorMessage(
                request_id=request_id,
                error=ErrorDetails.from_exception(e),
            )

        try:
            self._response_futures[request_id].set_result(response)
        except Exception as e:
            # e.g. future already resolved or removed; log and move on.
            self.exception(
                f"Exception setting response future for request {request_id}: {e}"
            )

    async def _wait_for_response(
        self, request_id: str, routing_envelope: tuple[bytes, ...]
    ) -> None:
        """Wait for a response to a request.

        This method will wait for the response future to be set and then send the response
        back to the client using the captured routing envelope.
        """
        try:
            # Wait for the response asynchronously.
            response = await self._response_futures[request_id]

            if response is None:
                # A handler returning None would leave the DEALER hanging;
                # reply with an explicit error instead.
                self.warning(
                    lambda req_id=request_id: f"Got None as response for request {req_id}"
                )
                response = ErrorMessage(
                    request_id=request_id,
                    error=ErrorDetails(
                        type="NO_RESPONSE",
                        message="No response was generated for the request.",
                    ),
                )

            # The future is resolved; remove it so the map does not grow.
            self._response_futures.pop(request_id, None)

            # Send the response back to the client, prefixed with the routing
            # envelope so the ROUTER socket can address the right DEALER.
            await self.socket.send_multipart(
                [*routing_envelope, response.model_dump_json().encode()]
            )
        except Exception as e:
            self.exception(
                f"Exception waiting for response for request {request_id}: {e}"
            )

    @background_task(immediate=True, interval=None)
    async def _rep_router_receiver(self) -> None:
        """Background task for receiving requests and sending responses.

        This method is a coroutine that will run indefinitely until the client is
        shutdown. It will wait for requests from the socket and send responses in
        an asynchronous manner.
        """
        self.debug("Router reply client background task initialized")

        while not self.stop_requested:
            try:
                # Receive request
                try:
                    data = await self.socket.recv_multipart()
                    self.trace(lambda msg=data: f"Received request: {msg}")

                    # The last frame is the message payload; everything before
                    # it is the ROUTER routing envelope.
                    request = Message.from_json(data[-1])
                    if not request.request_id:
                        self.exception(f"Request ID is missing from request: {data}")
                        continue

                    routing_envelope: tuple[bytes, ...] = (
                        tuple(data[:-1])
                        if len(data) > 1
                        else (request.request_id.encode(),)
                    )
                except zmq.Again:
                    # This means we timed out waiting for a request.
                    # We can continue to the next iteration of the loop.
                    self.debug("Router reply client receiver task timed out")
                    await yield_to_event_loop()
                    continue

                # Create a new response future for this request that will be resolved
                # when the handler returns a response.
                self._response_futures[request.request_id] = asyncio.Future()
                # Handle the request in a new task.
                self.execute_async(self._handle_request(request.request_id, request))
                self.execute_async(
                    self._wait_for_response(request.request_id, routing_envelope)
                )

            except (asyncio.CancelledError, zmq.ContextTerminated):
                # NOTE: this clause must come BEFORE `except Exception`:
                # zmq.ContextTerminated subclasses Exception, and placing it
                # after would swallow it and spin the loop on a dead context.
                # (This matches the handling in the push and sub clients.)
                self.debug("Router reply client receiver task cancelled")
                break
            except Exception as e:
                self.exception(f"Exception receiving request: {e}")
                await yield_to_event_loop()

__init__(address, bind, socket_ops=None, **kwargs)

Initialize the ZMQ Router (Rep) client class.

Parameters:

Name Type Description Default
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/router_reply_client.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def __init__(
    self,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
    **kwargs,
) -> None:
    """Create a ZMQ ROUTER reply client.

    Args:
        address (str): Endpoint to bind or connect to.
        bind (bool): True to bind the socket, False to connect it.
        socket_ops (dict, optional): Extra socket options to apply.
    """
    # Base class performs all socket setup; only the ROUTER type is fixed here.
    super().__init__(zmq.SocketType.ROUTER, address, bind, socket_ops, **kwargs)

    # message_type -> (service_id, async handler); one handler per type.
    self._request_handlers: dict[
        MessageTypeT,
        tuple[str, Callable[[Message], Coroutine[Any, Any, Message | None]]],
    ] = {}
    # request_id -> future resolved when the handler produces a response.
    self._response_futures: dict[str, asyncio.Future[Message | None]] = {}

register_request_handler(service_id, message_type, handler)

Register a request handler. Anytime a request is received that matches the message type, the handler will be called. The handler should return a response message. If the handler returns None, the request will be ignored.

Note that there is a limit of 1 to 1 mapping between message type and handler.

Parameters:

Name Type Description Default
service_id str

The service ID to register the handler for

required
message_type MessageTypeT

The message type to register the handler for

required
handler Callable[[Message], Coroutine[Any, Any, Message | None]]

The handler to register

required
Source code in aiperf/common/comms/zmq/router_reply_client.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def register_request_handler(
    self,
    service_id: str,
    message_type: MessageTypeT,
    handler: Callable[[Message], Coroutine[Any, Any, Message | None]],
) -> None:
    """Register *handler* for requests of *message_type*.

    Whenever a request with a matching message type arrives, the handler is
    awaited and its return value is sent back as the response. A handler
    returning None causes the request to be ignored.

    Only one handler may be registered per message type.

    Args:
        service_id: The service ID to register the handler for.
        message_type: The message type to register the handler for.
        handler: The coroutine-returning handler to register.

    Raises:
        ValueError: If a handler is already registered for this message type.
    """
    if message_type in self._request_handlers:
        raise ValueError(
            f"Handler already registered for message type {message_type}"
        )

    # Bind current values as lambda defaults so the lazy log message is
    # evaluated with them even if logging happens later.
    self.debug(
        lambda sid=service_id,
        mtype=message_type: f"Registering request handler for {sid} with message type {mtype}"
    )
    self._request_handlers[message_type] = (service_id, handler)

aiperf.common.comms.zmq.sub_client

ZMQSubClient

Bases: BaseZMQClient

ZMQ SUB socket client for subscribing to messages from PUB sockets. One-to-Many or Many-to-One communication pattern.

ASCII Diagram: ┌──────────────┐ ┌──────────────┐ │ PUB │───>│ │ │ (Publisher) │ │ │ └──────────────┘ │ SUB │ ┌──────────────┐ │ (Subscriber) │ │ PUB │───>│ │ │ (Publisher) │ │ │ └──────────────┘ └──────────────┘ OR ┌──────────────┐ ┌──────────────┐ │ │───>│ SUB │ │ │ │ (Subscriber) │ │ PUB │ └──────────────┘ │ (Publisher) │ ┌──────────────┐ │ │───>│ SUB │ │ │ │ (Subscriber) │ └──────────────┘ └──────────────┘

Usage Pattern: - Single SUB socket subscribes to multiple PUB publishers (One-to-Many) OR - Multiple SUB sockets subscribe to a single PUB publisher (Many-to-One)

  • Subscribes to specific message topics/types
  • Receives all messages matching subscriptions

SUB/PUB is a One-to-Many communication pattern. If you need Many-to-Many, add a ZMQ proxy in between; see :class:ZMQXPubXSubProxy for more details.

Source code in aiperf/common/comms/zmq/sub_client.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
@implements_protocol(SubClientProtocol)
@CommunicationClientFactory.register(CommClientType.SUB)
class ZMQSubClient(BaseZMQClient):
    """
    ZMQ SUB socket client for subscribing to messages from PUB sockets.
    One-to-Many or Many-to-One communication pattern.

    ASCII Diagram:
    ┌──────────────┐    ┌──────────────┐
    │     PUB      │───>│              │
    │ (Publisher)  │    │              │
    └──────────────┘    │     SUB      │
    ┌──────────────┐    │ (Subscriber) │
    │     PUB      │───>│              │
    │ (Publisher)  │    │              │
    └──────────────┘    └──────────────┘
    OR
    ┌──────────────┐    ┌──────────────┐
    │              │───>│     SUB      │
    │              │    │ (Subscriber) │
    │     PUB      │    └──────────────┘
    │ (Publisher)  │    ┌──────────────┐
    │              │───>│     SUB      │
    │              │    │ (Subscriber) │
    └──────────────┘    └──────────────┘


    Usage Pattern:
    - Single SUB socket subscribes to multiple PUB publishers (One-to-Many)
    OR
    - Multiple SUB sockets subscribe to a single PUB publisher (Many-to-One)

    - Subscribes to specific message topics/types
    - Receives all messages matching subscriptions

    SUB/PUB is a One-to-Many communication pattern. If you need Many-to-Many,
    use a ZMQ Proxy as well. see :class:`ZMQXPubXSubProxy` for more details.
    """

    def __init__(
        self,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the ZMQ Subscriber class.

        Args:
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
        """
        super().__init__(zmq.SocketType.SUB, address, bind, socket_ops, **kwargs)

        # topic (message type) -> list of callbacks invoked per message.
        self._subscribers: dict[MessageTypeT, list[Callable[[Message], Any]]] = {}

    async def subscribe_all(
        self,
        message_callback_map: dict[
            MessageTypeT,
            Callable[[Message], Any] | list[Callable[[Message], Any]],
        ],
    ) -> None:
        """Subscribe to all message_types in the map. For each MessageType, a single
        callback or a list of callbacks can be provided."""
        await self._check_initialized()
        for message_type, callbacks in message_callback_map.items():
            if isinstance(callbacks, list):
                for callback in callbacks:
                    await self._subscribe_internal(message_type, callback)
            else:
                await self._subscribe_internal(message_type, callbacks)

    async def subscribe(
        self, message_type: MessageTypeT, callback: Callable[[Message], Any]
    ) -> None:
        """Subscribe to a message_type.

        Args:
            message_type: MessageTypeT to subscribe to
            callback: Function to call when a message is received (receives Message object)

        Raises:
            CommunicationError: If the subscription could not be established.
        """
        await self._check_initialized()
        await self._subscribe_internal(message_type, callback)

    async def _subscribe_internal(
        self, topic: str, callback: Callable[[Message], Any]
    ) -> None:
        """Subscribe to a topic and record its callback.

        Args:
            topic: Topic (message type) to subscribe to
            callback: Function to call when a message is received (receives Message object)

        Raises:
            CommunicationError: If setting the socket subscription fails.
        """
        try:
            # Only subscribe to topic if this is the first callback for this type
            if topic not in self._subscribers:
                self.debug(lambda: f"Subscribed to topic: {topic}")
                self.socket.setsockopt(
                    zmq.SUBSCRIBE, topic.encode() + TOPIC_END_ENCODED
                )
            else:
                self.debug(
                    lambda: f"Adding callback to existing subscription for topic: {topic}"
                )

            self._subscribers.setdefault(topic, []).append(callback)

        except Exception as e:
            self.exception(f"Exception subscribing to topic {topic}: {e}")
            raise CommunicationError(
                f"Failed to subscribe to topic {topic}: {e}",
            ) from e

    async def _handle_message(self, topic_bytes: bytes, message_bytes: bytes) -> None:
        """Handle a message from a subscribed message_type."""

        # strip the final TOPIC_END chars from the topic
        topic = topic_bytes.decode()[: -len(TOPIC_END)]
        message_json = message_bytes.decode()
        self.trace(
            lambda: f"Received message from topic: '{topic}', message: {message_json}"
        )

        # Targeted messages are in the format "<message_type>.<target_service_id>"
        # or "<message_type>.<target_service_type>"
        # grab the first part which is the message type
        message_type = (
            topic.split(TOPIC_DELIMITER)[0] if TOPIC_DELIMITER in topic else topic
        )

        # COMMAND and COMMAND_RESPONSE have dedicated message classes;
        # everything else is parsed by its message-type registry.
        if message_type == MessageType.COMMAND:
            message = CommandMessage.from_json(message_json)
        elif message_type == MessageType.COMMAND_RESPONSE:
            message = CommandResponse.from_json(message_json)
        else:
            message = Message.from_json_with_type(message_type, message_json)

        self.debug(
            lambda: f"Calling callbacks for message: {message}, {self._subscribers.get(topic)}"
        )

        # Call callbacks with the parsed message object
        if topic in self._subscribers:
            with contextlib.suppress(Exception):  # Ignore errors, they will get logged
                await call_all_functions(self._subscribers[topic], message)

    @background_task(immediate=True, interval=None)
    async def _sub_receiver(self) -> None:
        """Background task for receiving messages from subscribed topics.

        This method is a coroutine that will run indefinitely until the client is
        shutdown. It will wait for messages from the socket and handle them.
        """
        while not self.stop_requested:
            try:
                topic_bytes, message_bytes = await self.socket.recv_multipart()
                if self.is_trace_enabled:
                    self.trace(
                        f"Socket received message: {topic_bytes} {message_bytes}"
                    )
                self.execute_async(self._handle_message(topic_bytes, message_bytes))

            except zmq.Again:
                self.debug(f"Sub client {self.client_id} receiver task timed out")
                await yield_to_event_loop()
            except (asyncio.CancelledError, zmq.ContextTerminated):
                # NOTE: this clause must come BEFORE `except Exception`:
                # zmq.ContextTerminated subclasses Exception, so placing the
                # generic clause first made this branch unreachable for it and
                # the loop would log forever on a terminated context.
                self.debug(f"Sub client {self.client_id} receiver task cancelled")
                break
            except Exception as e:
                self.exception(
                    f"Exception receiving message from subscription: {e}, {type(e)}"
                )
                await yield_to_event_loop()

__init__(address, bind, socket_ops=None, **kwargs)

Initialize the ZMQ Subscriber class.

Parameters:

Name Type Description Default
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/sub_client.py
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def __init__(
    self,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
    **kwargs,
) -> None:
    """Create a ZMQ SUB client.

    Args:
        address (str): Endpoint to bind or connect to.
        bind (bool): True to bind the socket, False to connect it.
        socket_ops (dict, optional): Extra socket options to apply.
    """
    # Socket setup is handled by the base class; only the SUB type is fixed.
    super().__init__(zmq.SocketType.SUB, address, bind, socket_ops, **kwargs)

    # topic (message type) -> callbacks to invoke for each received message.
    self._subscribers: dict[MessageTypeT, list[Callable[[Message], Any]]] = {}

subscribe(message_type, callback) async

Subscribe to a message_type.

Parameters:

Name Type Description Default
message_type MessageTypeT

MessageTypeT to subscribe to

required
callback Callable[[Message], Any]

Function to call when a message is received (receives Message object)

required
Source code in aiperf/common/comms/zmq/sub_client.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
async def subscribe(
    self, message_type: MessageTypeT, callback: Callable[[Message], Any]
) -> None:
    """Register *callback* to receive messages of *message_type*.

    Args:
        message_type: Message type (topic) to subscribe to.
        callback: Invoked with the parsed Message object on each delivery.

    Raises:
        CommunicationError: If the subscription could not be established.
    """
    # Ensure the socket is usable before touching subscription state.
    await self._check_initialized()
    await self._subscribe_internal(message_type, callback)

subscribe_all(message_callback_map) async

Subscribe to all message_types in the map. For each MessageType, a single callback or a list of callbacks can be provided.

Source code in aiperf/common/comms/zmq/sub_client.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
async def subscribe_all(
    self,
    message_callback_map: dict[
        MessageTypeT,
        Callable[[Message], Any] | list[Callable[[Message], Any]],
    ],
) -> None:
    """Subscribe to every message type in *message_callback_map*.

    Each map value may be a single callback or a list of callbacks; all of
    them are registered for their message type.
    """
    await self._check_initialized()
    for message_type, callbacks in message_callback_map.items():
        # Normalize to a list so both forms share one registration path.
        if not isinstance(callbacks, list):
            callbacks = [callbacks]
        for callback in callbacks:
            await self._subscribe_internal(message_type, callback)

aiperf.common.comms.zmq.zmq_base_client

BaseZMQClient

Bases: AIPerfLifecycleMixin

Base class for all ZMQ clients. It can be used as-is to create a new ZMQ client, or it can be subclassed to create specific ZMQ client functionality.

It inherits from the :class:AIPerfLifecycleMixin, allowing derived classes to implement specific hooks.

Source code in aiperf/common/comms/zmq/zmq_base_client.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class BaseZMQClient(AIPerfLifecycleMixin):
    """Base class for all ZMQ clients. It can be used as-is to create a new ZMQ client,
    or it can be subclassed to create specific ZMQ client functionality.

    It inherits from the :class:`AIPerfLifecycleMixin`, allowing derived
    classes to implement specific hooks.
    """

    def __init__(
        self,
        socket_type: zmq.SocketType,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
        client_id: str | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the ZMQ Base class.

        Args:
            address (str): The address to bind or connect to.
            bind (bool): Whether to BIND or CONNECT the socket.
            socket_type (SocketType): The type of ZMQ socket (eg. PUB, SUB, ROUTER, DEALER, etc.).
            socket_ops (dict, optional): Additional socket options to set.
            client_id (str, optional): Identifier for this client; auto-generated
                from the socket type and a short UUID when not provided.
        """
        # Use the process-wide shared asyncio context (Context.instance()),
        # not a fresh one, so all clients share one ZMQ context.
        self.context: zmq.asyncio.Context = zmq.asyncio.Context.instance()
        self.socket_type: zmq.SocketType = socket_type
        self.socket: zmq.asyncio.Socket = self.context.socket(self.socket_type)
        self.address: str = address
        self.bind: bool = bind
        self.socket_ops: dict = socket_ops or {}
        # Any falsy client_id (None or "") triggers auto-generation.
        self.client_id: str = (
            client_id
            or f"{self.socket_type.name.lower()}_client_{uuid.uuid4().hex[:8]}"
        )
        # The lifecycle mixin receives the client_id as its `id`.
        super().__init__(id=self.client_id, **kwargs)
        self.trace(lambda: f"ZMQ client __init__: {self.client_id}")

    async def _check_initialized(self) -> None:
        """Raise an exception if the socket is not initialized or closed.

        Raises:
            asyncio.CancelledError: If a stop has been requested.
            NotInitializedError: If the socket is missing/closed.
        """
        if self.stop_requested:
            raise asyncio.CancelledError("Socket was stopped")
        if not self.socket:
            raise NotInitializedError("Socket not initialized or closed")

    @property
    def socket_type_name(self) -> str:
        """Get the name of the socket type."""
        return self.socket_type.name

    @on_init
    async def _initialize_socket(self) -> None:
        """Initialize the communication.

        This method will:
        - Create the zmq socket
        - Bind or connect the socket to the address
        - Set the socket options
        - Run the AIPerfHook.ON_INIT hooks
        """
        try:
            self.debug(
                lambda: f"ZMQ {self.socket_type_name} socket initialized, try {'BIND' if self.bind else 'CONNECT'} to {self.address} ({self.client_id})"
            )

            if self.bind:
                self.socket.bind(self.address)
            else:
                self.socket.connect(self.address)

            # NOTE(review): socket options are applied AFTER bind/connect here;
            # some ZMQ options only take effect for subsequent connections —
            # confirm this ordering is intentional.
            # Set default timeouts
            self.socket.setsockopt(zmq.RCVTIMEO, ZMQSocketDefaults.RCVTIMEO)
            self.socket.setsockopt(zmq.SNDTIMEO, ZMQSocketDefaults.SNDTIMEO)

            # Set performance-oriented socket options
            self.socket.setsockopt(zmq.TCP_KEEPALIVE, ZMQSocketDefaults.TCP_KEEPALIVE)
            self.socket.setsockopt(
                zmq.TCP_KEEPALIVE_IDLE, ZMQSocketDefaults.TCP_KEEPALIVE_IDLE
            )
            self.socket.setsockopt(
                zmq.TCP_KEEPALIVE_INTVL, ZMQSocketDefaults.TCP_KEEPALIVE_INTVL
            )
            self.socket.setsockopt(
                zmq.TCP_KEEPALIVE_CNT, ZMQSocketDefaults.TCP_KEEPALIVE_CNT
            )
            self.socket.setsockopt(zmq.IMMEDIATE, ZMQSocketDefaults.IMMEDIATE)
            self.socket.setsockopt(zmq.LINGER, ZMQSocketDefaults.LINGER)

            # Set additional socket options requested by the caller
            # (applied last, so they can override the defaults above).
            for key, val in self.socket_ops.items():
                self.socket.setsockopt(key, val)

            self.debug(
                lambda: f"ZMQ {self.socket_type_name} socket {'BOUND' if self.bind else 'CONNECTED'} to {self.address} ({self.client_id})"
            )

        except Exception as e:
            raise InitializationError(f"Failed to initialize ZMQ socket: {e}") from e

    @on_stop
    async def _shutdown_socket(self) -> None:
        """Shutdown the socket.

        Close errors are logged (or ignored for an already-terminated
        context) rather than propagated, so shutdown always completes.
        """
        try:
            if self.socket:
                self.socket.close()
        except zmq.ContextTerminated:
            self.debug(
                lambda: f"ZMQ context already terminated, skipping socket close ({self.client_id})"
            )
            return
        except Exception as e:
            self.exception(
                f"Uncaught exception shutting down ZMQ socket: {e} ({self.client_id})"
            )

socket_type_name property

Get the name of the socket type.

__init__(socket_type, address, bind, socket_ops=None, client_id=None, **kwargs)

Initialize the ZMQ Base class.

Parameters:

Name Type Description Default
address str

The address to bind or connect to.

required
bind bool

Whether to BIND or CONNECT the socket.

required
socket_type SocketType

The type of ZMQ socket (eg. PUB, SUB, ROUTER, DEALER, etc.).

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/zmq_base_client.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def __init__(
    self,
    socket_type: zmq.SocketType,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
    client_id: str | None = None,
    **kwargs,
) -> None:
    """
    Initialize the ZMQ Base class.

    Args:
        socket_type (SocketType): The type of ZMQ socket (eg. PUB, SUB, ROUTER, DEALER, etc.).
        address (str): The address to bind or connect to.
        bind (bool): Whether to BIND or CONNECT the socket.
        socket_ops (dict, optional): Additional socket options to set.
        client_id (str, optional): Identifier for this client. When omitted, an id is
            generated from the socket type name plus a short random UUID suffix.
        **kwargs: Forwarded to the next class in the MRO via ``super().__init__``.
    """
    # All clients share the process-wide asyncio context singleton.
    self.context: zmq.asyncio.Context = zmq.asyncio.Context.instance()
    self.socket_type: zmq.SocketType = socket_type
    self.socket: zmq.asyncio.Socket = self.context.socket(self.socket_type)
    self.address: str = address
    self.bind: bool = bind
    self.socket_ops: dict = socket_ops or {}
    # Example generated id: "pub_client_1a2b3c4d".
    self.client_id: str = (
        client_id
        or f"{self.socket_type.name.lower()}_client_{uuid.uuid4().hex[:8]}"
    )
    super().__init__(id=self.client_id, **kwargs)
    self.trace(lambda: f"ZMQ client __init__: {self.client_id}")

aiperf.common.comms.zmq.zmq_comms

BaseZMQCommunication

Bases: BaseCommunication, AIPerfLoggerMixin, ABC

ZeroMQ-based implementation of the CommunicationProtocol.

Uses ZeroMQ for publish/subscribe, request/reply, and pull/push patterns to facilitate communication between AIPerf components.

Source code in aiperf/common/comms/zmq/zmq_comms.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
@implements_protocol(CommunicationProtocol)
class BaseZMQCommunication(BaseCommunication, AIPerfLoggerMixin, ABC):
    """ZeroMQ-based implementation of the CommunicationProtocol.

    Uses ZeroMQ for publish/subscribe, request/reply, and pull/push patterns to
    facilitate communication between AIPerf components.
    """

    def __init__(
        self,
        config: BaseZMQCommunicationConfig,
    ) -> None:
        """Initialize the communication layer with the given ZMQ config."""
        super().__init__()
        self.config = config

        # All clients share the process-wide asyncio context singleton.
        self.context = zmq.asyncio.Context.instance()
        # Cache keyed by (client type, address, bind flag) so repeated
        # create_client calls return the same client instance.
        self._clients_cache: dict[
            tuple[CommClientType, CommAddressType, bool], CommunicationClientProtocol
        ] = {}

        self.debug(f"ZMQ communication using protocol: {type(self.config).__name__}")

    def get_address(self, address_type: CommAddressType) -> str:
        """Get the actual address based on the address type from the config."""
        if isinstance(address_type, CommAddress):
            return self.config.get_address(address_type)
        # Anything else is assumed to already be a literal address string.
        return address_type

    def create_client(
        self,
        client_type: CommClientType,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
        max_pull_concurrency: int | None = None,
        **kwargs,
    ) -> CommunicationClientProtocol:
        """Create a communication client for a given client type and address.

        Args:
            client_type: The type of client to create.
            address: The type of address to use when looking up in the communication config, or the address itself.
            bind: Whether to bind or connect the socket.
            socket_ops: Additional socket options to set.
            max_pull_concurrency: The maximum number of concurrent pull requests to allow. (Only used for pull clients)

        Raises:
            InvalidStateError: If called after this class has left the CREATED state.
        """
        if (client_type, address, bind) in self._clients_cache:
            return self._clients_cache[(client_type, address, bind)]

        if self.state != LifecycleState.CREATED:
            # We require the clients to be created before the communication class is initialized.
            # This is because this class manages the lifecycle of the clients as well.
            raise InvalidStateError(
                f"Communication clients must be created before the {self.__class__.__name__} "
                f"class is initialized: {self.state!r}"
            )

        client = CommunicationClientFactory.create_instance(
            client_type,
            address=self.get_address(address),
            bind=bind,
            socket_ops=socket_ops,
            max_pull_concurrency=max_pull_concurrency,
            **kwargs,
        )

        self._clients_cache[(client_type, address, bind)] = client
        # Tie the client's lifecycle to this communication object's.
        self.attach_child_lifecycle(client)
        return client

create_client(client_type, address, bind=False, socket_ops=None, max_pull_concurrency=None, **kwargs)

Create a communication client for a given client type and address.

Parameters:

Name Type Description Default
client_type CommClientType

The type of client to create.

required
address CommAddressType

The type of address to use when looking up in the communication config, or the address itself.

required
bind bool

Whether to bind or connect the socket.

False
socket_ops dict | None

Additional socket options to set.

None
max_pull_concurrency int | None

The maximum number of concurrent pull requests to allow. (Only used for pull clients)

None
Source code in aiperf/common/comms/zmq/zmq_comms.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def create_client(
    self,
    client_type: CommClientType,
    address: CommAddressType,
    bind: bool = False,
    socket_ops: dict | None = None,
    max_pull_concurrency: int | None = None,
    **kwargs,
) -> CommunicationClientProtocol:
    """Create a communication client for a given client type and address.

    Args:
        client_type: The type of client to create.
        address: The type of address to use when looking up in the communication config, or the address itself.
        bind: Whether to bind or connect the socket.
        socket_ops: Additional socket options to set.
        max_pull_concurrency: The maximum number of concurrent pull requests to allow. (Only used for pull clients)

    Raises:
        InvalidStateError: If called after this class has left the CREATED state.
    """
    # Reuse an existing client for the same (type, address, bind) triple.
    if (client_type, address, bind) in self._clients_cache:
        return self._clients_cache[(client_type, address, bind)]

    if self.state != LifecycleState.CREATED:
        # We require the clients to be created before the communication class is initialized.
        # This is because this class manages the lifecycle of the clients as well.
        raise InvalidStateError(
            f"Communication clients must be created before the {self.__class__.__name__} "
            f"class is initialized: {self.state!r}"
        )

    client = CommunicationClientFactory.create_instance(
        client_type,
        address=self.get_address(address),
        bind=bind,
        socket_ops=socket_ops,
        max_pull_concurrency=max_pull_concurrency,
        **kwargs,
    )

    self._clients_cache[(client_type, address, bind)] = client
    # Tie the client's lifecycle to this communication object's.
    self.attach_child_lifecycle(client)
    return client

get_address(address_type)

Get the actual address based on the address type from the config.

Source code in aiperf/common/comms/zmq/zmq_comms.py
54
55
56
57
58
def get_address(self, address_type: CommAddressType) -> str:
    """Resolve ``address_type`` to a concrete address string.

    A ``CommAddress`` value is looked up in the communication config; any
    other value is assumed to already be a literal address and is returned
    unchanged.
    """
    if not isinstance(address_type, CommAddress):
        return address_type
    return self.config.get_address(address_type)

ZMQIPCCommunication

Bases: BaseZMQCommunication

ZeroMQ-based implementation of the Communication interface using IPC transport.

Source code in aiperf/common/comms/zmq/zmq_comms.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
@CommunicationFactory.register(CommunicationBackend.ZMQ_IPC)
@implements_protocol(CommunicationProtocol)
class ZMQIPCCommunication(BaseZMQCommunication):
    """ZeroMQ-based implementation of the Communication interface using IPC transport."""

    def __init__(self, config: ZMQIPCConfig | None = None) -> None:
        """Initialize ZMQ IPC communication.

        Args:
            config: ZMQIPCConfig object with configuration parameters.
                A default ZMQIPCConfig is used when not provided.
        """
        super().__init__(config or ZMQIPCConfig())
        # call after super init so that way self.config is set
        self._setup_ipc_directory()

    def _setup_ipc_directory(self) -> None:
        """Create IPC socket directory if using IPC transport."""
        self._ipc_socket_dir = Path(self.config.path)
        if not self._ipc_socket_dir.exists():
            self.debug(
                f"IPC socket directory does not exist, creating: {self._ipc_socket_dir}"
            )
            # exist_ok=True guards against a concurrent creator between the
            # exists() check and mkdir().
            self._ipc_socket_dir.mkdir(parents=True, exist_ok=True)
            self.debug(f"Created IPC socket directory: {self._ipc_socket_dir}")
        else:
            self.debug(f"IPC socket directory already exists: {self._ipc_socket_dir}")

    @on_stop
    def _cleanup_ipc_sockets(self) -> None:
        """Clean up IPC socket files on lifecycle stop."""
        if self._ipc_socket_dir and self._ipc_socket_dir.exists():
            # Remove all .ipc files in the directory. Use pathlib consistently
            # with the rest of the class (previously a mix of glob/os/errno);
            # missing_ok=True covers a file disappearing between glob() and
            # unlink(), which the old code handled via an errno.ENOENT check.
            for ipc_file in self._ipc_socket_dir.glob("*.ipc"):
                try:
                    ipc_file.unlink(missing_ok=True)
                    self.debug(f"Removed IPC socket file: {ipc_file}")
                except OSError as e:
                    self.warning(
                        f"Failed to remove IPC socket file {ipc_file}: {e}"
                    )

__init__(config=None)

Initialize ZMQ IPC communication.

Parameters:

Name Type Description Default
config ZMQIPCConfig | None

ZMQIPCConfig object with configuration parameters

None
Source code in aiperf/common/comms/zmq/zmq_comms.py
122
123
124
125
126
127
128
129
130
def __init__(self, config: ZMQIPCConfig | None = None) -> None:
    """Initialize ZMQ IPC communication.

    Args:
        config: ZMQIPCConfig object with configuration parameters.
            A default ZMQIPCConfig is used when not provided.
    """
    super().__init__(config or ZMQIPCConfig())
    # call after super init so that way self.config is set
    self._setup_ipc_directory()

ZMQTCPCommunication

Bases: BaseZMQCommunication

ZeroMQ-based implementation of the Communication interface using TCP transport.

Source code in aiperf/common/comms/zmq/zmq_comms.py
103
104
105
106
107
108
109
110
111
112
113
114
@CommunicationFactory.register(CommunicationBackend.ZMQ_TCP)
@implements_protocol(CommunicationProtocol)
class ZMQTCPCommunication(BaseZMQCommunication):
    """ZeroMQ-based implementation of the Communication interface using TCP transport."""

    def __init__(self, config: ZMQTCPConfig | None = None) -> None:
        """Initialize ZMQ TCP communication.

        Args:
            config: ZMQTCPConfig object with configuration parameters.
                A default ZMQTCPConfig is used when not provided.
        """
        super().__init__(config or ZMQTCPConfig())

__init__(config=None)

Initialize ZMQ TCP communication.

Parameters:

Name Type Description Default
config ZMQTCPConfig | None

ZMQTCPConfig object with configuration parameters

None
Source code in aiperf/common/comms/zmq/zmq_comms.py
108
109
110
111
112
113
114
def __init__(self, config: ZMQTCPConfig | None = None) -> None:
    """Initialize ZMQ TCP communication.

    Args:
        config: ZMQTCPConfig object with configuration parameters.
            A default ZMQTCPConfig is used when not provided.
    """
    super().__init__(config or ZMQTCPConfig())

aiperf.common.comms.zmq.zmq_defaults

TOPIC_DELIMITER = '.' module-attribute

The delimiter between topic parts. This is used to create an inverted hierarchy of topics for filtering by service type or service id.

For example: - "command" - "system_controller.command" - "timing_manager_eff34565.command"

TOPIC_END = '$' module-attribute

This is used to add to the end of each topic to prevent the topic from being a prefix of another topic. This is required for the PUB/SUB pattern to work correctly, otherwise topics like "command_response" will be received by the "command" subscriber as well.

For example: - "command$" - "command_response$"

TOPIC_END_ENCODED = TOPIC_END.encode() module-attribute

The encoded version of TOPIC_END.

ZMQSocketDefaults

Default values for ZMQ sockets.

Source code in aiperf/common/comms/zmq/zmq_defaults.py
29
30
31
32
33
34
35
36
37
38
39
40
class ZMQSocketDefaults:
    """Default values for ZMQ sockets."""

    # Socket Options
    RCVTIMEO = 300000  # Receive timeout in ms (5 minutes)
    SNDTIMEO = 300000  # Send timeout in ms (5 minutes)
    TCP_KEEPALIVE = 1  # Enable TCP keepalive probes
    TCP_KEEPALIVE_IDLE = 60  # Idle time before the first keepalive probe
    TCP_KEEPALIVE_INTVL = 10  # Interval between keepalive probes
    TCP_KEEPALIVE_CNT = 3  # Failed probes tolerated before dropping the connection
    IMMEDIATE = 1  # Don't queue messages
    LINGER = 0  # Don't wait on close

aiperf.common.comms.zmq.zmq_proxy_base

BaseZMQProxy

Bases: AIPerfLifecycleMixin, ABC

A Base ZMQ Proxy class.

  • Frontend and backend sockets forward messages bidirectionally
    • Frontend and Backend sockets both BIND
  • Multiple clients CONNECT to frontend_address
  • Multiple services CONNECT to backend_address
  • Control: Optional REP socket for proxy commands (start/stop/pause) - not implemented yet
  • Monitoring: Optional PUB socket that broadcasts copies of all forwarded messages - not implemented yet
  • Proxy runs in separate thread to avoid blocking main event loop
Source code in aiperf/common/comms/zmq/zmq_proxy_base.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
class BaseZMQProxy(AIPerfLifecycleMixin, ABC):
    """
    A Base ZMQ Proxy class.

    - Frontend and backend sockets forward messages bidirectionally
        - Frontend and Backend sockets both BIND
    - Multiple clients CONNECT to `frontend_address`
    - Multiple services CONNECT to `backend_address`
    - Control: Optional REP socket for proxy commands (start/stop/pause) - not implemented yet
    - Monitoring: Optional PUB socket that broadcasts copies of all forwarded messages - not implemented yet
    - Proxy runs in separate thread to avoid blocking main event loop
    """

    def __init__(
        self,
        frontend_socket_class: type[BaseZMQClient],
        backend_socket_class: type[BaseZMQClient],
        zmq_proxy_config: BaseZMQProxyConfig,
        socket_ops: dict | None = None,
        proxy_uuid: str | None = None,
    ) -> None:
        """Initialize the ZMQ Proxy. This is a base class for all ZMQ Proxies.

        Args:
            frontend_socket_class (type[BaseZMQClient]): The frontend socket class.
            backend_socket_class (type[BaseZMQClient]): The backend socket class.
            zmq_proxy_config (BaseZMQProxyConfig): The ZMQ proxy configuration.
            socket_ops (dict, optional): Additional socket options to set.
            proxy_uuid (str, optional): An optional UUID for the proxy instance. If not provided,
                a new UUID will be generated. This is useful for tracing and debugging purposes.
        """

        # Ids are set before super().__init__() so lifecycle machinery can
        # reference them from the start.
        self.proxy_uuid = proxy_uuid or uuid.uuid4().hex[:8]
        self.proxy_id = f"{self.__class__.__name__.lower()}_{self.proxy_uuid}"
        super().__init__()
        self.context = zmq.asyncio.Context.instance()
        self.socket_ops = socket_ops

        self.monitor_task: asyncio.Task | None = None
        self.proxy_task: asyncio.Task | None = None
        self.control_client: ProxySocketClient | None = None
        self.capture_client: ProxySocketClient | None = None

        self.frontend_address = zmq_proxy_config.frontend_address
        self.backend_address = zmq_proxy_config.backend_address
        self.control_address = zmq_proxy_config.control_address
        self.capture_address = zmq_proxy_config.capture_address

        self.debug(
            lambda: f"Proxy Initializing - Frontend: {self.frontend_address}, Backend: {self.backend_address}"
        )

        self.backend_socket = backend_socket_class(
            address=self.backend_address,
            socket_ops=self.socket_ops,
            proxy_uuid=self.proxy_uuid,  # Pass the proxy UUID for tracing
        )  # type: ignore

        self.frontend_socket = frontend_socket_class(
            address=self.frontend_address,
            socket_ops=self.socket_ops,
            proxy_uuid=self.proxy_uuid,  # Pass the proxy UUID for tracing
        )  # type: ignore

        # Optional control (REP) and capture (PUB) endpoints are only created
        # when the config provides an address for them.
        if self.control_address:
            self.debug(lambda: f"Proxy Control - Address: {self.control_address}")
            self.control_client = ProxySocketClient(
                socket_type=SocketType.REP,
                address=self.control_address,
                socket_ops=self.socket_ops,
                end_type=ProxyEndType.Control,
                proxy_uuid=self.proxy_uuid,
            )

        if self.capture_address:
            self.debug(lambda: f"Proxy Capture - Address: {self.capture_address}")
            self.capture_client = ProxySocketClient(
                socket_type=SocketType.PUB,
                address=self.capture_address,
                socket_ops=self.socket_ops,
                end_type=ProxyEndType.Capture,
                proxy_uuid=self.proxy_uuid,
            )

    @classmethod
    @abstractmethod
    def from_config(
        cls,
        config: BaseZMQProxyConfig | None,
        socket_ops: dict | None = None,
    ) -> "BaseZMQProxy | None":
        """Create a BaseZMQProxy from a BaseZMQProxyConfig, or None if not provided."""
        ...

    @on_init
    async def _initialize(self) -> None:
        """Initialize all proxy sockets concurrently, raising on any failure."""
        self.debug("Proxy Initializing Sockets...")
        self.debug(
            lambda: f"Frontend {self.frontend_socket.socket_type.name} socket binding to: {self.frontend_address} (for {self.backend_socket.socket_type.name} clients)"
        )
        self.debug(
            lambda: f"Backend {self.backend_socket.socket_type.name} socket binding to: {self.backend_address} (for {self.frontend_socket.socket_type.name} services)"
        )

        try:
            results = await asyncio.gather(
                self.backend_socket.initialize(),
                self.frontend_socket.initialize(),
                *[
                    client.initialize()
                    for client in [self.control_client, self.capture_client]
                    if client
                ],
                return_exceptions=True,
            )
            # With return_exceptions=True, failures are returned as values
            # instead of raising, so they must be collected and surfaced
            # explicitly. (A bare `raise` here was a bug: inside a `try`
            # body there is no active exception to re-raise.)
            errors = [res for res in results if isinstance(res, BaseException)]
            if errors:
                self.exception(f"Proxy Socket Initialization Failed: {errors}")
                raise errors[0]

            self.debug("Proxy Sockets Initialized Successfully")

            if self.control_client:
                self.debug(lambda: f"Control socket bound to: {self.control_address}")
            if self.capture_client:
                self.debug(lambda: f"Capture socket bound to: {self.capture_address}")

        except Exception as e:
            self.exception(f"Proxy Socket Initialization Failed: {e}")
            raise

    @on_start
    async def _start_proxy(self) -> None:
        """Start the Base ZMQ Proxy.

        This method starts the proxy and waits for it to complete asynchronously.
        The proxy forwards messages between the frontend and backend sockets.

        Raises:
            ProxyError: If the proxy produces an error.
        """
        self.debug("Starting Proxy...")

        # zmq.proxy_steerable blocks until terminated, so it runs in a worker
        # thread to keep the event loop responsive.
        self.proxy_task = asyncio.create_task(
            asyncio.to_thread(
                zmq.proxy_steerable,
                self.frontend_socket.socket,
                self.backend_socket.socket,
                capture=self.capture_client.socket if self.capture_client else None,
                control=self.control_client.socket if self.control_client else None,
            )
        )

    @background_task(immediate=True, interval=None)
    async def _monitor_messages(self) -> None:
        """Monitor messages flowing through the proxy via the capture socket."""
        if not self.capture_client or not self.capture_address:
            self.debug("Proxy Monitor Not Enabled")
            return

        self.debug(
            lambda: f"Proxy Monitor Starting - Capture Address: {self.capture_address}"
        )

        capture_socket = self.context.socket(SocketType.SUB)
        capture_socket.connect(self.capture_address)
        self.debug(
            lambda: f"Proxy Monitor Connected to Capture Address: {self.capture_address}"
        )
        capture_socket.setsockopt(zmq.SUBSCRIBE, b"")  # Subscribe to all messages
        self.debug("Proxy Monitor Subscribed to all messages")

        try:
            while not self.stop_requested:
                recv_msg = await capture_socket.recv_multipart()
                # Bind the message as a default arg so the lambda logs this
                # iteration's value, not the loop variable's latest value.
                self.debug(lambda msg=recv_msg: f"Proxy Monitor Received: {msg}")
        except Exception as e:
            self.exception(f"Proxy Monitor Error - {e}")
            raise
        except asyncio.CancelledError:
            # CancelledError derives from BaseException (3.8+), so it is not
            # caught by the Exception handler above; treat as a normal exit.
            return
        finally:
            capture_socket.close()

    @on_stop
    async def _stop_proxy(self) -> None:
        """Shutdown the BaseZMQProxy: cancel the proxy thread task and stop all sockets."""
        self.debug("Proxy Stopping...")
        if self.proxy_task:
            self.proxy_task.cancel()
            self.proxy_task = None
        await self.frontend_socket.stop()
        await self.backend_socket.stop()
        if self.control_client:
            await self.control_client.stop()
        if self.capture_client:
            await self.capture_client.stop()

__init__(frontend_socket_class, backend_socket_class, zmq_proxy_config, socket_ops=None, proxy_uuid=None)

Initialize the ZMQ Proxy. This is a base class for all ZMQ Proxies.

Parameters:

Name Type Description Default
frontend_socket_class type[BaseZMQClient]

The frontend socket class.

required
backend_socket_class type[BaseZMQClient]

The backend socket class.

required
zmq_proxy_config BaseZMQProxyConfig

The ZMQ proxy configuration.

required
socket_ops dict

Additional socket options to set.

None
proxy_uuid str

An optional UUID for the proxy instance. If not provided, a new UUID will be generated. This is useful for tracing and debugging purposes.

None
Source code in aiperf/common/comms/zmq/zmq_proxy_base.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
def __init__(
    self,
    frontend_socket_class: type[BaseZMQClient],
    backend_socket_class: type[BaseZMQClient],
    zmq_proxy_config: BaseZMQProxyConfig,
    socket_ops: dict | None = None,
    proxy_uuid: str | None = None,
) -> None:
    """Initialize the ZMQ Proxy. This is a base class for all ZMQ Proxies.

    Args:
        frontend_socket_class (type[BaseZMQClient]): The frontend socket class.
        backend_socket_class (type[BaseZMQClient]): The backend socket class.
        zmq_proxy_config (BaseZMQProxyConfig): The ZMQ proxy configuration.
        socket_ops (dict, optional): Additional socket options to set.
        proxy_uuid (str, optional): An optional UUID for the proxy instance. If not provided,
            a new UUID will be generated. This is useful for tracing and debugging purposes.
    """

    # Ids are set before super().__init__() so lifecycle machinery can
    # reference them from the start.
    self.proxy_uuid = proxy_uuid or uuid.uuid4().hex[:8]
    self.proxy_id = f"{self.__class__.__name__.lower()}_{self.proxy_uuid}"
    super().__init__()
    self.context = zmq.asyncio.Context.instance()
    self.socket_ops = socket_ops

    self.monitor_task: asyncio.Task | None = None
    self.proxy_task: asyncio.Task | None = None
    self.control_client: ProxySocketClient | None = None
    self.capture_client: ProxySocketClient | None = None

    self.frontend_address = zmq_proxy_config.frontend_address
    self.backend_address = zmq_proxy_config.backend_address
    self.control_address = zmq_proxy_config.control_address
    self.capture_address = zmq_proxy_config.capture_address

    self.debug(
        lambda: f"Proxy Initializing - Frontend: {self.frontend_address}, Backend: {self.backend_address}"
    )

    self.backend_socket = backend_socket_class(
        address=self.backend_address,
        socket_ops=self.socket_ops,
        proxy_uuid=self.proxy_uuid,  # Pass the proxy UUID for tracing
    )  # type: ignore

    self.frontend_socket = frontend_socket_class(
        address=self.frontend_address,
        socket_ops=self.socket_ops,
        proxy_uuid=self.proxy_uuid,  # Pass the proxy UUID for tracing
    )  # type: ignore

    # Optional control (REP) and capture (PUB) endpoints are only created
    # when the config provides an address for them.
    if self.control_address:
        self.debug(lambda: f"Proxy Control - Address: {self.control_address}")
        self.control_client = ProxySocketClient(
            socket_type=SocketType.REP,
            address=self.control_address,
            socket_ops=self.socket_ops,
            end_type=ProxyEndType.Control,
            proxy_uuid=self.proxy_uuid,
        )

    if self.capture_address:
        self.debug(lambda: f"Proxy Capture - Address: {self.capture_address}")
        self.capture_client = ProxySocketClient(
            socket_type=SocketType.PUB,
            address=self.capture_address,
            socket_ops=self.socket_ops,
            end_type=ProxyEndType.Capture,
            proxy_uuid=self.proxy_uuid,
        )

from_config(config, socket_ops=None) abstractmethod classmethod

Create a BaseZMQProxy from a BaseZMQProxyConfig, or None if not provided.

Source code in aiperf/common/comms/zmq/zmq_proxy_base.py
138
139
140
141
142
143
144
145
146
@classmethod
@abstractmethod
def from_config(
    cls,
    config: BaseZMQProxyConfig | None,
    socket_ops: dict | None = None,
) -> "BaseZMQProxy | None":
    """Create a BaseZMQProxy from a BaseZMQProxyConfig, or None if not provided.

    Args:
        config: The proxy configuration, or None to skip proxy creation.
        socket_ops: Additional socket options to set on the proxy sockets.
    """
    ...

ProxySocketClient

Bases: BaseZMQClient

A ZMQ Proxy socket client class that extends BaseZMQClient.

This class is used to create proxy sockets for the frontend, backend, capture, and control endpoint types of a ZMQ Proxy.

Source code in aiperf/common/comms/zmq/zmq_proxy_base.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class ProxySocketClient(BaseZMQClient):
    """ZMQ client specialization for the sockets of a ZMQ Proxy.

    Instances represent the frontend, backend, capture, or control endpoint
    of a proxy. Proxy sockets always BIND to their address.
    """

    def __init__(
        self,
        socket_type: SocketType,
        address: str,
        end_type: ProxyEndType,
        socket_ops: dict | None = None,
        proxy_uuid: str | None = None,
    ) -> None:
        # Build a descriptive client id; fall back to a short random suffix
        # when no proxy UUID was supplied.
        uid = proxy_uuid or uuid.uuid4().hex[:8]
        self.client_id = f"proxy_{end_type}_{socket_type.name.lower()}_{uid}"
        super().__init__(
            socket_type,
            address,
            bind=True,
            socket_ops=socket_ops,
            client_id=self.client_id,
        )
        self.debug(
            lambda: f"ZMQ Proxy {end_type.name} {socket_type.name} - Address: {address}"
        )

aiperf.common.comms.zmq.zmq_proxy_sockets

ZMQDealerRouterProxy = define_proxy_class(ZMQProxyType.DEALER_ROUTER, create_proxy_socket_class(SocketType.ROUTER, ProxyEndType.Frontend), create_proxy_socket_class(SocketType.DEALER, ProxyEndType.Backend)) module-attribute

A ROUTER socket for the proxy's frontend and a DEALER socket for the proxy's backend.

ASCII Diagram: ┌───────────┐ ┌──────────────────────────────────┐ ┌───────────┐ │ DEALER │<───>│ PROXY │<────>│ ROUTER │ │ Client 1 │ │ ┌──────────┐ ┌──────────┐ │ │ Service 1 │ └───────────┘ │ │ ROUTER │<─────> │ DEALER │ │ └───────────┘ ┌───────────┐ │ │ Frontend │ │ Backend │ │ ┌───────────┐ │ DEALER │<───>│ └──────────┘ └──────────┘ │<────>│ ROUTER │ │ Client N │ └──────────────────────────────────┘ │ Service N │ └───────────┘ └───────────┘

The ROUTER frontend socket receives messages from DEALER clients and forwards them through the proxy to ROUTER services. The ZMQ proxy handles the message routing automatically.

The DEALER backend socket receives messages from ROUTER services and forwards them through the proxy to DEALER clients. The ZMQ proxy handles the message routing automatically.

CRITICAL: This socket must NOT have an identity when used in a proxy configuration, as it needs to be transparent to preserve routing envelopes for proper response forwarding back to original DEALER clients.

ZMQPushPullProxy = define_proxy_class(ZMQProxyType.PUSH_PULL, create_proxy_socket_class(SocketType.PULL, ProxyEndType.Frontend), create_proxy_socket_class(SocketType.PUSH, ProxyEndType.Backend)) module-attribute

A PULL socket for the proxy's frontend and a PUSH socket for the proxy's backend.

ASCII Diagram: ┌───────────┐ ┌─────────────────────────────────┐ ┌───────────┐ │ PUSH │─────>│ PROXY │─────>│ PULL │ │ Client 1 │ │ ┌──────────┐ ┌──────────┐ │ │ Service 1 │ └───────────┘ │ │ PULL │──────>│ PUSH │ │ └───────────┘ ┌───────────┐ │ │ Frontend │ │ Backend │ │ ┌───────────┐ │ PUSH │─────>│ └──────────┘ └──────────┘ │─────>│ PULL │ │ Client N │ └─────────────────────────────────┘ │ Service N │ └───────────┘ └───────────┘

The PULL frontend socket receives messages from PUSH clients and forwards them through the proxy to PUSH services. The ZMQ proxy handles the message routing automatically.

The PUSH backend socket forwards messages from the proxy to PULL services. The ZMQ proxy handles the message routing automatically.

ZMQXPubXSubProxy = define_proxy_class(ZMQProxyType.XPUB_XSUB, create_proxy_socket_class(SocketType.XSUB, ProxyEndType.Frontend), create_proxy_socket_class(SocketType.XPUB, ProxyEndType.Backend)) module-attribute

An XSUB socket for the proxy's frontend and an XPUB socket for the proxy's backend.

ASCII Diagram: ┌───────────┐ ┌─────────────────────────────────┐ ┌───────────┐ │ PUB │───>│ PROXY │───>│ SUB │ │ Client 1 │ │ ┌──────────┐ ┌──────────┐ │ │ Service 1 │ └───────────┘ │ │ XSUB │──────>│ XPUB │ │ └───────────┘ ┌───────────┐ │ │ Frontend │ │ Backend │ │ ┌───────────┐ │ PUB │───>│ └──────────┘ └──────────┘ │───>│ SUB │ │ Client N │ └─────────────────────────────────┘ │ Service N │ └───────────┘ └───────────┘

The XSUB frontend socket receives messages from PUB clients and forwards them through the proxy to XPUB services. The ZMQ proxy handles the message routing automatically.

The XPUB backend socket forwards messages from the proxy to SUB services. The ZMQ proxy handles the message routing automatically.

create_proxy_socket_class(socket_type, end_type)

Create a proxy socket class using the specified socket type. This is used to reduce the boilerplate code required to create a ZMQ Proxy class.

Source code in aiperf/common/comms/zmq/zmq_proxy_sockets.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def create_proxy_socket_class(
    socket_type: SocketType,
    end_type: ProxyEndType,
) -> type[BaseZMQClient]:
    """Create a proxy socket class using the specified socket type. This is used to
    reduce the boilerplate code required to create a ZMQ Proxy class.

    Args:
        socket_type: The ZMQ socket type the generated class will use
            (e.g. PULL, PUSH, XSUB, XPUB).
        end_type: Which end of the proxy (frontend or backend) the socket is for.

    Returns:
        A dynamically named ProxySocketClient subclass bound to the given
        socket type and proxy end.
    """

    # Encode both the proxy end and the socket type in the generated name,
    # e.g. "ZMQProxyFrontendSocketXSUB".
    class_name = f"ZMQProxy{end_type.name}Socket{socket_type.name}"

    class ProxySocket(ProxySocketClient):
        """A ZMQ Proxy socket class with a specific socket type."""

        def __init__(
            self,
            address: str,
            socket_ops: dict | None = None,
            proxy_uuid: str | None = None,
        ):
            """Initialize the ZMQ Proxy socket class.

            Args:
                address: The ZMQ address for this socket.
                socket_ops: Optional extra socket options forwarded to the base class.
                proxy_uuid: Optional identifier of the owning proxy.
            """

            # socket_type and end_type are captured from the enclosing factory call.
            super().__init__(
                socket_type,
                address,
                end_type=end_type,
                socket_ops=socket_ops,
                proxy_uuid=proxy_uuid,
            )

        @on_init
        async def _initialize_socket(self) -> None:
            """Initialize the socket with proper configuration for XPUB/XSUB proxy."""
            if self.socket_type == SocketType.XPUB:
                # XPUB_VERBOSE=1 makes the socket pass through all subscription
                # messages, which the proxy relies on for subscription forwarding.
                self.socket.setsockopt(zmq.XPUB_VERBOSE, 1)
                self.debug(
                    lambda: "XPUB socket configured with XPUB_VERBOSE=1 for subscription forwarding"
                )

    # Dynamically set the class name and qualname based on the socket and end type
    ProxySocket.__name__ = class_name
    ProxySocket.__qualname__ = class_name
    ProxySocket.__doc__ = f"A ZMQ Proxy {end_type.name} socket implementation."
    return ProxySocket

define_proxy_class(proxy_type, frontend_socket_class, backend_socket_class)

This function reduces the boilerplate code required to create a ZMQ Proxy class. It will generate a ZMQ Proxy class and register it with the ZMQProxyFactory.

Parameters:

Name Type Description Default
proxy_type ZMQProxyType

The type of proxy to generate.

required
frontend_socket_class type[BaseZMQClient]

The class of the frontend socket.

required
backend_socket_class type[BaseZMQClient]

The class of the backend socket.

required
Source code in aiperf/common/comms/zmq/zmq_proxy_sockets.py
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def define_proxy_class(
    proxy_type: ZMQProxyType,
    frontend_socket_class: type[BaseZMQClient],
    backend_socket_class: type[BaseZMQClient],
) -> type[BaseZMQProxy]:
    """This function reduces the boilerplate code required to create a ZMQ Proxy class.
    It will generate a ZMQ Proxy class and register it with the ZMQProxyFactory.

    Args:
        proxy_type: The type of proxy to generate.
        frontend_socket_class: The class of the frontend socket.
        backend_socket_class: The class of the backend socket.

    Returns:
        The generated BaseZMQProxy subclass. As a side effect, the class is
        also registered with ZMQProxyFactory under ``proxy_type``.
    """

    class ZMQProxy(BaseZMQProxy):
        """
        A Generated ZMQ Proxy class.

        This class is responsible for creating the ZMQ proxy that forwards messages
        between frontend and backend sockets.
        """

        def __init__(
            self,
            zmq_proxy_config: BaseZMQProxyConfig,
            socket_ops: dict | None = None,
        ) -> None:
            # The frontend/backend socket classes are captured from the
            # enclosing define_proxy_class call.
            super().__init__(
                frontend_socket_class=frontend_socket_class,
                backend_socket_class=backend_socket_class,
                zmq_proxy_config=zmq_proxy_config,
                socket_ops=socket_ops,
            )

        @classmethod
        def from_config(
            cls,
            config: BaseZMQProxyConfig | None,
            socket_ops: dict | None = None,
        ) -> "ZMQProxy | None":
            """Alternate constructor that returns None when no config is provided."""
            if config is None:
                return None
            return cls(
                zmq_proxy_config=config,
                socket_ops=socket_ops,
            )

    # Dynamically set the class name and qualname based on the proxy type
    ZMQProxy.__name__ = f"ZMQ_{proxy_type.name}_Proxy"
    ZMQProxy.__qualname__ = ZMQProxy.__name__
    ZMQProxy.__doc__ = f"A ZMQ Proxy for {proxy_type.name} communication."
    ZMQProxyFactory.register(proxy_type)(ZMQProxy)
    return ZMQProxy

aiperf.common.config.audio_config

AudioConfig

Bases: BaseConfig

A configuration class for defining audio related settings.

Source code in aiperf/common/config/audio_config.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class AudioConfig(BaseConfig):
    """
    A configuration class for defining audio related settings.

    Groups the CLI options that control synthetic audio inputs: batch size,
    clip length, file format, bit depths, sample rates, and channel count.
    """

    # All audio options are rendered under the same CLI help group.
    _CLI_GROUP = Groups.AUDIO_INPUT

    # Number of audio items attached to a single request.
    batch_size: Annotated[
        int,
        Field(
            ge=0,
            description="The batch size of audio requests AIPerf should send.\n"
            "This is currently supported with the OpenAI `chat` endpoint type",
        ),
        Parameter(
            name=(
                "--audio-batch-size",
                "--batch-size-audio",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = AudioDefaults.BATCH_SIZE

    # Nested mean/stddev settings for the audio length in seconds.
    length: AudioLengthConfig = AudioLengthConfig()

    # Audio file format (wav or mp3).
    format: Annotated[
        AudioFormat,
        Field(
            description="The format of the audio files (wav or mp3).",
        ),
        Parameter(
            name=(
                "--audio-format",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = AudioDefaults.FORMAT

    # Candidate bit depths (bits); one is randomly selected per generated audio.
    depths: Annotated[
        list[int],
        Field(
            min_length=1,
            description="A list of audio bit depths to randomly select from in bits.",
        ),
        BeforeValidator(parse_str_or_list_of_positive_values),
        Parameter(
            name=(
                "--audio-depths",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = AudioDefaults.DEPTHS

    # Candidate sample rates (kHz); one is randomly selected per generated audio.
    sample_rates: Annotated[
        list[float],
        Field(
            min_length=1,
            description="A list of audio sample rates to randomly select from in kHz.\n"
            "Common sample rates are 16, 44.1, 48, 96, etc.",
        ),
        BeforeValidator(parse_str_or_list_of_positive_values),
        Parameter(
            name=(
                "--audio-sample-rates",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = AudioDefaults.SAMPLE_RATES

    # 1 = mono, 2 = stereo (bounded by ge=1 / le=2).
    num_channels: Annotated[
        int,
        Field(
            ge=1,
            le=2,
            description="The number of audio channels to use for the audio data generation.",
        ),
        Parameter(
            name=(
                "--audio-num-channels",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = AudioDefaults.NUM_CHANNELS

AudioLengthConfig

Bases: BaseConfig

A configuration class for defining audio length related settings.

Source code in aiperf/common/config/audio_config.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class AudioLengthConfig(BaseConfig):
    """
    A configuration class for defining audio length related settings.

    The audio length is modeled as a distribution described by a mean and a
    standard deviation, both in seconds.
    """

    # Rendered under the same CLI help group as the other audio options.
    _CLI_GROUP = Groups.AUDIO_INPUT

    # Mean audio length in seconds (non-negative).
    mean: Annotated[
        float,
        Field(
            ge=0,
            description="The mean length of the audio in seconds.",
        ),
        Parameter(
            name=(
                "--audio-length-mean",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = AudioDefaults.LENGTH_MEAN

    # Standard deviation of the audio length in seconds (non-negative).
    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of the length of the audio in seconds.",
        ),
        Parameter(
            name=(
                "--audio-length-stddev",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = AudioDefaults.LENGTH_STDDEV

aiperf.common.config.base_config

BaseConfig

Bases: AIPerfBaseModel

Base configuration class for all configurations.

Source code in aiperf/common/config/base_config.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
class BaseConfig(AIPerfBaseModel):
    """
    Base configuration class for all configurations.

    Provides YAML serialization that can annotate the output with comments
    derived from the Pydantic field descriptions of the concrete subclass.
    """

    def serialize_to_yaml(self, verbose: bool = False, indent: int = 4) -> str:
        """
        Serialize a Pydantic model to a YAML string.

        Args:
            verbose: Whether to include verbose comments in the YAML output.
            indent: The per-level indentation to use.

        Returns:
            The YAML document as a string.
        """
        # Dump model to dict with context (flags propagate recursively)
        context = {
            "verbose": verbose,
        }

        data = self.model_dump(context=context)

        # Attach comments recursively
        commented_data = self._attach_comments(
            data=data,
            model=self,
            context=context,
            indent=indent,
        )

        # Dump to YAML. pure=True selects the pure-Python round-trip emitter,
        # which is what renders the attached comments.
        yaml = YAML(pure=True)
        yaml.indent(mapping=indent, sequence=indent, offset=indent)

        stream = io.StringIO()
        yaml.dump(commented_data, stream)
        return stream.getvalue()

    @staticmethod
    def _attach_comments(
        data: Any,
        model: AIPerfBaseModel,
        context: dict,
        indent: int,
        indent_level: int = 0,
    ) -> Any:
        """
        Recursively convert dicts to ruamel.yaml CommentedMap and attach comments from
        Pydantic field descriptions, or based on context (e.g., verbose flag).

        Args:
            data: The raw data to convert to a CommentedMap.
            model: The Pydantic model that contains the field descriptions.
            context: The Pydantic serializer context which contains the serializer flags.
            indent: The per-level indentation to use for the comments.
            indent_level: The current level of indentation. The actual indentation is
                `indent * indent_level`.

        Returns:
            The data with comments attached. Non-dict data is returned unchanged.
        """
        if not isinstance(data, dict):
            # Fix: previously this case fell off the end of the function and
            # implicitly returned None, silently dropping non-dict input.
            return data

        # CommentedMap is a ruamel.yaml dict subclass that preserves key order
        # and allows comments to be attached to individual keys.
        commented_map = CommentedMap()

        for field_name, value in data.items():
            field = model.__class__.model_fields.get(field_name)

            if not BaseConfig._should_add_field_to_template(field):
                continue

            if BaseConfig._is_a_nested_config(field, value):
                # Recursively process nested models
                commented_map[field_name] = BaseConfig._attach_comments(
                    value,
                    getattr(model, field_name),
                    context=context,
                    indent=indent,
                    indent_level=indent_level + 1,
                )

                # Insert a blank line before each nested section for readability.
                commented_map.yaml_set_comment_before_after_key(
                    field_name,
                    before="\n",
                    indent=indent * (indent_level + 1),
                )
            else:
                # Attach the (leaf) value to the commented map
                commented_map[field_name] = BaseConfig._preprocess_value(value)

            # Attach comment if verbose and description exists
            if context.get("verbose") and field and field.description:
                # Set the comment before the key, with the specified indentation
                commented_map.yaml_set_comment_before_after_key(
                    field_name,
                    before="\n" + field.description,
                    indent=indent * indent_level,
                )

        return commented_map

    @staticmethod
    def _should_add_field_to_template(field: Any) -> bool:
        """Return True unless the field's json_schema_extra opts out via ADD_TO_TEMPLATE."""
        # If add_to_template is False, skip the field in the template.
        # If it is True or absent, include the field.
        if field and field.json_schema_extra:
            return field.json_schema_extra.get(ADD_TO_TEMPLATE, True)
        return True

    @staticmethod
    def _is_a_nested_config(field: Any, value: Any) -> bool:
        """Return True when ``value`` is the dump of a nested AIPerfBaseModel field."""
        # Fix: guard with isinstance(..., type) because issubclass() raises
        # TypeError for generic or union annotations (e.g. dict[str, Any]),
        # which are legitimate dict-valued fields.
        return bool(
            isinstance(value, dict)
            and field
            and isinstance(field.annotation, type)
            and issubclass(field.annotation, AIPerfBaseModel)
        )

    @staticmethod
    def _preprocess_value(value: Any) -> Any:
        """
        Preprocess the value before serialization.

        Enums are flattened to their lowercased value and Paths to plain
        strings; everything else passes through untouched.
        """
        if isinstance(value, Enum):
            return str(value.value).lower()
        if isinstance(value, Path):
            return str(value)
        return value

serialize_to_yaml(verbose=False, indent=4)

Serialize a Pydantic model to a YAML string.

Parameters:

Name Type Description Default
verbose bool

Whether to include verbose comments in the YAML output.

False
indent int

The per-level indentation to use.

4
Source code in aiperf/common/config/base_config.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def serialize_to_yaml(self, verbose: bool = False, indent: int = 4) -> str:
    """
    Serialize a Pydantic model to a YAML string.

    Args:
        verbose: Whether to include verbose comments in the YAML output.
        indent: The per-level indentation to use.

    Returns:
        The YAML document as a string.
    """
    # Dump model to dict with context (flags propagate recursively)
    context = {
        "verbose": verbose,
    }

    data = self.model_dump(context=context)

    # Attach comments recursively
    commented_data = self._attach_comments(
        data=data,
        model=self,
        context=context,
        indent=indent,
    )

    # Dump to YAML. pure=True selects the pure-Python round-trip emitter,
    # which is what renders the attached comments.
    yaml = YAML(pure=True)
    yaml.indent(mapping=indent, sequence=indent, offset=indent)

    stream = io.StringIO()
    yaml.dump(commented_data, stream)
    return stream.getvalue()

aiperf.common.config.config_defaults

aiperf.common.config.config_validators

parse_file(value)

Parses the given string value and returns a Path object if the value represents a valid file or directory. Returns None if the input value is empty. Args: value (str): The string value to parse. Returns: Optional[Path]: A Path object if the value is valid, or None if the value is empty. Raises: ValueError: If the value is not a valid file or directory.

Source code in aiperf/common/config/config_validators.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def parse_file(value: str | None) -> Path | None:
    """
    Parses the given string value and returns a Path object if the value represents
    a valid file or directory. Returns None if the input value is empty.
    Args:
        value (str): The string value to parse.
    Returns:
        Optional[Path]: A Path object if the value is valid, or None if the value is empty.
    Raises:
        ValueError: If the value is not a valid file or directory.
    """

    if not value:
        return None
    elif not isinstance(value, str):
        raise ValueError(f"Expected a string, but got {type(value).__name__}")
    else:
        path = Path(value)
        if path.is_file() or path.is_dir():
            return path
        else:
            raise ValueError(f"'{value}' is not a valid file or directory")

parse_goodput(goodputs)

Parses and validates a dictionary of goodput values, ensuring that all values are non-negative integers or floats, and converts them to floats. Args: goodputs (Dict[str, Any]): A dictionary where keys are target metric names (strings) and values are the corresponding goodput values. Returns: Dict[str, float]: A dictionary with the same keys as the input, but with all values converted to floats. Raises: ValueError: If any value in the input dictionary is not an integer or float, or if any value is negative.

Source code in aiperf/common/config/config_validators.py
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def parse_goodput(goodputs: dict[str, Any]) -> dict[str, float]:
    """
    Parses and validates a dictionary of goodput values, ensuring that all values
    are non-negative integers or floats, and converts them to floats.
    Args:
        goodputs (Dict[str, Any]): A dictionary where keys are target metric names
            (strings) and values are the corresponding goodput values.
    Returns:
        Dict[str, float]: A dictionary with the same keys as the input, but with
            all values converted to floats.
    Raises:
        ValueError: If any value in the input dictionary is not an integer or float,
            or if any value is negative.
    """

    constraints = {}
    for target_metric, target_value in goodputs.items():
        if isinstance(target_value, (int | float)):
            if target_value < 0:
                raise ValueError(
                    f"User Config: Goodput values must be non-negative ({target_metric}: {target_value})"
                )

            constraints[target_metric] = float(target_value)
        else:
            raise ValueError("User Config: Goodput values must be integers or floats")

    return constraints

parse_service_types(input)

Parses the input to ensure it is a set of service types. Will replace hyphens with underscores for user convenience.

Source code in aiperf/common/config/config_validators.py
73
74
75
76
77
78
79
80
81
82
def parse_service_types(input: Any | None) -> set[ServiceType] | None:
    """Normalize the input into a set of ServiceType values.
    Hyphens are replaced with underscores for user convenience."""
    if input is None:
        return None

    normalized = (
        service_type.replace("-", "_") for service_type in parse_str_or_csv_list(input)
    )
    return {ServiceType(name) for name in normalized}

parse_str_or_csv_list(input)

Parses the input to ensure it is either a string or a list. If the input is a string, it splits the string by commas and trims any whitespace around each element, returning the result as a list. If the input is already a list, it will split each item by commas and trim any whitespace around each element, returning the combined result as a list. If the input is neither a string nor a list, a ValueError is raised.

[1, 2, 3] -> [1, 2, 3]
"1,2,3" -> ["1", "2", "3"]
["1,2,3", "4,5,6"] -> ["1", "2", "3", "4", "5", "6"]
["1,2,3", 4, 5] -> ["1", "2", "3", 4, 5]

Source code in aiperf/common/config/config_validators.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def parse_str_or_csv_list(input: Any) -> list[Any]:
    """
    Normalize a string or list into a flat list of values.

    A string is split on commas with surrounding whitespace removed from each
    element. A list is flattened by applying the same comma-splitting to each
    string item, while non-string items are kept as-is. Any other input type
    raises a ValueError.

    [1, 2, 3] -> [1, 2, 3]
    "1,2,3" -> ["1", "2", "3"]
    ["1,2,3", "4,5,6"] -> ["1", "2", "3", "4", "5", "6"]
    ["1,2,3", 4, 5] -> ["1", "2", "3", 4, 5]
    """
    if isinstance(input, str):
        return [piece.strip() for piece in input.split(",")]

    if isinstance(input, list):
        flattened: list[Any] = []
        for element in input:
            if not isinstance(element, str):
                flattened.append(element)
            else:
                flattened.extend(piece.strip() for piece in element.split(","))
        return flattened

    raise ValueError(f"User Config: {input} - must be a string or list")

parse_str_or_dict(input)

Parses the input to ensure it is a dictionary.

  • If the input is a string:
    • If the string starts with a '{', it is parsed as a JSON string.
    • Otherwise, it splits the string by commas and then for each item, it splits the item by colons into key and value, trims any whitespace.
  • If the input is already a dictionary, it is returned as-is.
  • If the input is a list, it is converted to a dictionary by splitting each string by colons into key and value, trims any whitespace.
  • Otherwise, a ValueError is raised.

Parameters:

Name Type Description Default
input Any

The input to be parsed. Expected to be a string, list, or dictionary.

required

Returns: dict[str, Any]: A dictionary derived from the input. Raises: ValueError: If the input is neither a string, list, nor dictionary, or if the parsing fails.

Source code in aiperf/common/config/config_validators.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def parse_str_or_dict(input: Any | None) -> dict[str, Any] | None:
    """
    Parses the input to ensure it is a dictionary.

    - If the input is a string:
        - If the string starts with a '{', it is parsed as a JSON string.
        - Otherwise, it splits the string by commas and then for each item, it splits the item by colons
        into key and value, trims any whitespace.
    - If the input is already a dictionary, it is returned as-is.
    - If the input is a list, it is converted to a dictionary by splitting each string by colons
    into key and value, trims any whitespace.
    - Otherwise, a ValueError is raised.

    Args:
        input (Any): The input to be parsed. Expected to be a string, list, or dictionary.
    Returns:
        dict[str, Any]: A dictionary derived from the input.
    Raises:
        ValueError: If the input is neither a string, list, nor dictionary, or if the parsing fails.
    """

    if input is None:
        return None

    if isinstance(input, dict):
        return input

    if isinstance(input, list):
        return {
            key.strip(): value.strip()
            for item in input
            for key, value in [item.split(":")]
        }

    if isinstance(input, str):
        if input.startswith("{"):
            try:
                return json.loads(input)
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"User Config: {input} - must be a valid JSON string"
                ) from e
        else:
            return {
                key.strip(): value.strip()
                for item in input.split(",")
                for key, value in [item.split(":")]
            }

    raise ValueError(f"User Config: {input} - must be a valid string, list, or dict")

parse_str_or_list(input)

Parses the input to ensure it is either a string or a list. If the input is a string, it splits the string by commas and trims any whitespace around each element, returning the result as a list. If the input is already a list, it is returned as-is. If the input is neither a string nor a list, a ValueError is raised. Args: input (Any): The input to be parsed. Expected to be a string or a list. Returns: list: A list of strings derived from the input. Raises: ValueError: If the input is neither a string nor a list.

Source code in aiperf/common/config/config_validators.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def parse_str_or_list(input: Any) -> list[Any]:
    """
    Normalize a string or list into a flat list.

    A string is split on commas with surrounding whitespace removed from each
    element. For a list, each string item is comma-split the same way and
    non-string items are kept unchanged.

    Args:
        input (Any): The input to be parsed. Expected to be a string or a list.
    Returns:
        list: A list of strings derived from the input.
    Raises:
        ValueError: If the input is neither a string nor a list.
    """
    if isinstance(input, str):
        return [piece.strip() for piece in input.split(",")]

    if isinstance(input, list):
        # TODO: When using cyclopts, the values are already lists, so we have to split them by commas.
        flattened: list[Any] = []
        for element in input:
            if not isinstance(element, str):
                flattened.append(element)
            else:
                flattened.extend(piece.strip() for piece in element.split(","))
        return flattened

    raise ValueError(f"User Config: {input} - must be a string or list")

parse_str_or_list_of_positive_values(input)

Parses the input to ensure it is a list of positive integers or floats. This function first converts the input into a list using parse_str_or_list. It then validates that each value in the list is either an integer or a float and that all values are strictly greater than zero. If any value fails this validation, a ValueError is raised. Args: input (Any): The input to be parsed. It can be a string or a list. Returns: List[Any]: A list of positive integers or floats. Raises: ValueError: If any value in the parsed list is not a positive integer or float.

Source code in aiperf/common/config/config_validators.py
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def parse_str_or_list_of_positive_values(input: Any) -> list[Any]:
    """
    Parse the input into a list of strictly positive numbers.

    The input is first normalized with `parse_str_or_list`, then every element
    is converted to a float when its textual form contains a decimal point or
    an exponent marker, otherwise to an int. All values must be > 0.

    Args:
        input (Any): The input to be parsed. It can be a string or a list.
    Returns:
        List[Any]: A list of positive integers or floats.
    Raises:
        ValueError: If any value in the parsed list is not a positive integer or float.
    """

    output = parse_str_or_list(input)

    def _to_number(raw: Any) -> int | float:
        # "1.5" / "1e3" style tokens become floats, everything else becomes int.
        text = str(raw)
        return float(raw) if "." in text or "e" in text.lower() else int(raw)

    try:
        output = [_to_number(raw) for raw in output]
    except ValueError as e:
        raise ValueError(f"User Config: {output} - all values must be numeric") from e

    if any(not (isinstance(x, (int | float)) and x > 0) for x in output):
        raise ValueError(f"User Config: {output} - all values must be positive numbers")

    return output

aiperf.common.config.conversation_config

ConversationConfig

Bases: BaseConfig

A configuration class for defining conversations related settings.

Source code in aiperf/common/config/conversation_config.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class ConversationConfig(BaseConfig):
    """
    A configuration class for defining conversations related settings.
    """

    # All conversation options are rendered under the same CLI help group.
    _CLI_GROUP = Groups.CONVERSATION_INPUT

    # Total number of unique conversations to generate (>= 1); synthetic mode
    # only, and conversations are reused until benchmarking completes.
    num: Annotated[
        int,
        Field(
            ge=1,
            description="The total number of unique conversations to generate.\n"
            "Each conversation represents a single request session between client and server.\n"
            "Supported on synthetic mode only and conversations will be reused until benchmarking is complete.",
        ),
        Parameter(
            name=(
                "--conversation-num",
                "--num-conversations",
                "--num-sessions",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = ConversationDefaults.NUM

    # Nested per-turn settings (turn count distribution and inter-turn delays).
    turn: TurnConfig = TurnConfig()

TurnConfig

Bases: BaseConfig

A configuration class for defining turn related settings in a conversation.

Source code in aiperf/common/config/conversation_config.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
class TurnConfig(BaseConfig):
    """
    A configuration class for defining turn related settings in a conversation.
    """

    # Rendered under the same CLI help group as the other conversation options.
    _CLI_GROUP = Groups.CONVERSATION_INPUT

    # Mean number of turns per conversation (>= 1).
    mean: Annotated[
        int,
        Field(
            ge=1,
            description="The mean number of turns within a conversation.",
        ),
        Parameter(
            name=(
                "--conversation-turn-mean",
                "--session-turns-mean",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = TurnDefaults.MEAN

    # Standard deviation of the number of turns per conversation.
    stddev: Annotated[
        int,
        Field(
            ge=0,
            description="The standard deviation of the number of turns within a conversation.",
        ),
        Parameter(
            name=(
                "--conversation-turn-stddev",
                "--session-turns-stddev",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = TurnDefaults.STDDEV

    # Nested settings for the delay between turns (milliseconds).
    delay: TurnDelayConfig = TurnDelayConfig()

TurnDelayConfig

Bases: BaseConfig

A configuration class for defining turn delay related settings.

Source code in aiperf/common/config/conversation_config.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class TurnDelayConfig(BaseConfig):
    """
    A configuration class for defining turn delay related settings.

    The delay between turns is modeled as a mean and standard deviation in
    milliseconds, plus a scaling ratio.
    """

    # Rendered under the same CLI help group as the other conversation options.
    _CLI_GROUP = Groups.CONVERSATION_INPUT

    # Mean inter-turn delay in milliseconds.
    mean: Annotated[
        float,
        Field(
            ge=0,
            description="The mean delay between turns within a conversation in milliseconds.",
        ),
        Parameter(
            name=(
                "--conversation-turn-delay-mean",
                "--session-turn-delay-mean",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = TurnDelayDefaults.MEAN

    # Standard deviation of the inter-turn delay in milliseconds.
    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of the delay between turns \n"
            "within a conversation in milliseconds.",
        ),
        Parameter(
            name=(
                "--conversation-turn-delay-stddev",
                "--session-turn-delay-stddev",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = TurnDelayDefaults.STDDEV

    # Multiplier applied to scale multi-turn delays.
    ratio: Annotated[
        float,
        Field(
            ge=0,
            description="A ratio to scale multi-turn delays.",
        ),
        Parameter(
            name=(
                "--conversation-turn-delay-ratio",
                "--session-delay-ratio",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = TurnDelayDefaults.RATIO

aiperf.common.config.endpoint_config

EndpointConfig

Bases: BaseConfig

A configuration class for defining endpoint related settings.

Source code in aiperf/common/config/endpoint_config.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
class EndpointConfig(BaseConfig):
    """
    A configuration class for defining endpoint related settings.
    """

    _CLI_GROUP = Groups.ENDPOINT

    @model_validator(mode="after")
    def validate_streaming(self) -> Self:
        """Force streaming off when the selected endpoint type cannot stream.

        Only warns when streaming was actually requested; previously the
        warning (and the no-op assignment) fired even when streaming was
        already False, which was noisy for non-streaming runs.
        """
        if self.streaming and not self.type.supports_streaming:
            _logger.warning(
                f"Streaming is not supported for --endpoint-type {self.type}, setting streaming to False"
            )
            self.streaming = False
        return self

    # Required field (Field(...)): at least one model name must be provided.
    model_names: Annotated[
        list[str],
        Field(
            ...,
            description="Model name(s) to be benchmarked. Can be a comma-separated list or a single model name.",
        ),
        BeforeValidator(parse_str_or_list),
        Parameter(
            name=(
                "--model-names",
                "--model",  # GenAI-Perf
                "-m",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ]

    # How prompts are mapped onto models when more than one model is given.
    model_selection_strategy: Annotated[
        ModelSelectionStrategy,
        Field(
            description="When multiple models are specified, this is how a specific model should be assigned to a prompt.\n"
            "round_robin: nth prompt in the list gets assigned to n-mod len(models).\n"
            "random: assignment is uniformly random",
        ),
        Parameter(
            name=(
                "--model-selection-strategy",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = EndpointDefaults.MODEL_SELECTION_STRATEGY

    # Optional override of the endpoint path when it differs from OpenAI defaults.
    custom_endpoint: Annotated[
        str | None,
        Field(
            description="Set a custom endpoint that differs from the OpenAI defaults.",
        ),
        Parameter(
            name=(
                "--custom-endpoint",
                "--endpoint",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = EndpointDefaults.CUSTOM_ENDPOINT

    # Endpoint type; also consulted by validate_streaming (supports_streaming).
    type: Annotated[
        EndpointType,
        Field(
            description="The type to send requests to on the server.",
        ),
        Parameter(
            name=(
                "--endpoint-type",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = EndpointDefaults.TYPE

    # Whether to use the streaming API; may be forced off by validate_streaming.
    streaming: Annotated[
        bool,
        Field(
            description="An option to enable the use of the streaming API.",
        ),
        Parameter(
            name=(
                "--streaming",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = EndpointDefaults.STREAMING

    # Triton telemetry endpoints; CLI parsing disabled until supported.
    server_metrics_urls: Annotated[
        list[str],
        Field(
            description="The list of Triton server metrics URLs.\n"
            "These are used for Telemetry metric reporting with Triton.",
        ),
        BeforeValidator(parse_str_or_list),
        Parameter(
            name=(
                "--server-metrics-urls",  # GenAI-Perf
                "--server-metrics-url",  # GenAI-Perf
            ),
            parse=False,  # TODO: Not yet supported
            group=_CLI_GROUP,
        ),
    ] = EndpointDefaults.SERVER_METRICS_URLS

    # Base URL of the server under test.
    url: Annotated[
        str,
        Field(
            description="URL of the endpoint to target for benchmarking.",
        ),
        Parameter(
            name=(
                "--url",  # GenAI-Perf
                "-u",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = EndpointDefaults.URL

    # Fully-qualified gRPC method; CLI parsing disabled until supported.
    grpc_method: Annotated[
        str,
        Field(
            description="A fully-qualified gRPC method name in "
            "'<package>.<service>/<method>' format.\n"
            "The option is only supported by dynamic gRPC service kind and is\n"
            "required to identify the RPC to use when sending requests to the server.",
        ),
        Parameter(
            name=(
                "--grpc-method",  # GenAI-Perf
            ),
            parse=False,  # TODO: Not yet supported
            group=_CLI_GROUP,
        ),
    ] = EndpointDefaults.GRPC_METHOD

    # NEW AIPerf Option
    # Per-request timeout, in seconds (float allows sub-second values).
    timeout_seconds: Annotated[
        float,
        Field(
            description="The timeout in floating points seconds for each request to the endpoint.",
        ),
        Parameter(
            name=("--request-timeout-seconds"),
            group=_CLI_GROUP,
        ),
    ] = EndpointDefaults.TIMEOUT

    # NEW AIPerf Option
    # Sent as an `Authorization: Bearer <api_key>` header when provided.
    api_key: Annotated[
        str | None,
        Field(
            description="The API key to use for the endpoint. If provided, it will be sent with every request as "
            "a header: `Authorization: Bearer <api_key>`.",
        ),
        Parameter(
            name=("--api-key"),
            group=_CLI_GROUP,
        ),
    ] = EndpointDefaults.API_KEY

aiperf.common.config.groups

Groups

Groups for the CLI.

NOTE: The order of these groups are the order they will be displayed in the help text.

Source code in aiperf/common/config/groups.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Groups:
    """Groups for the CLI.

    NOTE: The order of these groups are the order they will be displayed in the help text.
    """

    # Each group is registered via Group.create_ordered, so the declaration
    # order below is load-bearing: do not reorder these without intending to
    # reorder the CLI help output.
    ENDPOINT = Group.create_ordered("Endpoint")
    INPUT = Group.create_ordered("Input")
    OUTPUT = Group.create_ordered("Output")
    TOKENIZER = Group.create_ordered("Tokenizer")
    LOAD_GENERATOR = Group.create_ordered("Load Generator")
    CONVERSATION_INPUT = Group.create_ordered("Conversation Input")
    INPUT_SEQUENCE_LENGTH = Group.create_ordered("Input Sequence Length (ISL)")
    OUTPUT_SEQUENCE_LENGTH = Group.create_ordered("Output Sequence Length (OSL)")
    PROMPT = Group.create_ordered("Prompt")
    PREFIX_PROMPT = Group.create_ordered("Prefix Prompt")
    AUDIO_INPUT = Group.create_ordered("Audio Input")
    IMAGE_INPUT = Group.create_ordered("Image Input")
    MEASUREMENT = Group.create_ordered("Measurement")
    SERVICE = Group.create_ordered("Service")
    WORKERS = Group.create_ordered("Workers")
    DEVELOPER = Group.create_ordered("Developer")

aiperf.common.config.image_config

ImageConfig

Bases: BaseConfig

A configuration class for defining image related settings.

Source code in aiperf/common/config/image_config.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
class ImageConfig(BaseConfig):
    """
    A configuration class for defining image related settings.

    Composes width/height sub-configs with batch-size and format options
    for synthetic image input generation.
    """

    _CLI_GROUP = Groups.IMAGE_INPUT

    # Nested sub-configs contribute their own CLI flags (--image-width-*,
    # --image-height-*) under the same Image Input group.
    width: ImageWidthConfig = ImageWidthConfig()
    height: ImageHeightConfig = ImageHeightConfig()
    # Number of images attached per request; per the description, only the
    # image retrieval endpoint type honors this.
    batch_size: Annotated[
        int,
        Field(
            ge=0,
            description="The image batch size of the requests AIPerf should send.\n"
            "This is currently supported with the image retrieval endpoint type.",
        ),
        Parameter(
            name=(
                "--image-batch-size",
                "--batch-size-image",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = ImageDefaults.BATCH_SIZE

    # Compression format for generated images (see the ImageFormat enum).
    format: Annotated[
        ImageFormat,
        Field(
            description="The compression format of the images.",
        ),
        Parameter(
            name=(
                "--image-format",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = ImageDefaults.FORMAT

ImageHeightConfig

Bases: BaseConfig

A configuration class for defining image height related settings.

Source code in aiperf/common/config/image_config.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class ImageHeightConfig(BaseConfig):
    """
    A configuration class for defining image height related settings.

    Mean/stddev describe the normal-ish distribution used when generating
    synthetic image data.
    """

    _CLI_GROUP = Groups.IMAGE_INPUT

    # Mean image height in pixels (non-negative).
    mean: Annotated[
        float,
        Field(
            ge=0,
            description="The mean height of images when generating synthetic image data.",
        ),
        Parameter(
            name=(
                "--image-height-mean",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = ImageDefaults.HEIGHT_MEAN

    # Standard deviation of image height; 0 means a fixed height.
    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of height of images when generating synthetic image data.",
        ),
        Parameter(
            name=(
                "--image-height-stddev",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = ImageDefaults.HEIGHT_STDDEV

ImageWidthConfig

Bases: BaseConfig

A configuration class for defining image width related settings.

Source code in aiperf/common/config/image_config.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class ImageWidthConfig(BaseConfig):
    """
    A configuration class for defining image width related settings.

    Mirrors ImageHeightConfig: mean/stddev for the width distribution used
    when generating synthetic image data.
    """

    _CLI_GROUP = Groups.IMAGE_INPUT

    # Mean image width in pixels (non-negative).
    mean: Annotated[
        float,
        Field(
            ge=0,
            description="The mean width of images when generating synthetic image data.",
        ),
        Parameter(
            name=(
                "--image-width-mean",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = ImageDefaults.WIDTH_MEAN

    # Standard deviation of image width; 0 means a fixed width.
    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of width of images when generating synthetic image data.",
        ),
        Parameter(
            name=(
                "--image-width-stddev",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = ImageDefaults.WIDTH_STDDEV

aiperf.common.config.input_config

InputConfig

Bases: BaseConfig

A configuration class for defining input related settings.

Source code in aiperf/common/config/input_config.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
class InputConfig(BaseConfig):
    """
    A configuration class for defining input related settings.
    """

    _CLI_GROUP = Groups.INPUT

    @model_validator(mode="after")
    def validate_fixed_schedule(self) -> Self:
        """Validate the fixed schedule configuration.

        Raises when a fixed schedule is requested without a file; enables
        fixed-schedule mode implicitly whenever a file is provided.
        """
        if self.fixed_schedule and self.file is None:
            raise ValueError("Fixed schedule requires a file to be provided")
        if self.file is not None:
            self.fixed_schedule = True
            logger.debug("Fixed schedule is enabled because file is provided")
        return self

    # Extra key/value payload merged into every request body.
    extra: Annotated[
        dict[str, Any] | None,
        Field(
            description="Provide additional inputs to include with every request.\n"
            "Inputs should be in an 'input_name:value' format.",
        ),
        Parameter(
            name=(
                "--extra-inputs",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
        BeforeValidator(parse_str_or_dict),
    ] = InputDefaults.EXTRA

    # Goodput constraints, parsed from 'metric:threshold' pairs.
    goodput: Annotated[
        dict[str, Any],
        Field(
            description="An option to provide constraints in order to compute goodput.\n"
            "Specify goodput constraints as 'key:value' pairs,\n"
            "where the key is a valid metric name, and the value is a number representing\n"
            "either milliseconds or a throughput value per second.\n"
            "For example: request_latency:300,output_token_throughput_per_user:600",
        ),
        Parameter(
            name=(
                "--goodput",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
        BeforeValidator(parse_goodput),
    ] = InputDefaults.GOODPUT

    # Custom HTTP headers added to every request.
    headers: Annotated[
        dict[str, str] | None,
        Field(
            description="Adds a custom header to the requests.\n"
            "Headers must be specified as 'Header:Value' pairs.",
        ),
        BeforeValidator(parse_str_or_dict),
        Parameter(
            name=(
                "--header",  # GenAI-Perf
                "-H",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = InputDefaults.HEADERS

    # Dataset file/directory; providing it flips fixed_schedule on (see validator).
    file: Annotated[
        Any,
        Field(
            description="The file or directory path that contains the dataset to use for profiling.\n"
            "This parameter is used in conjunction with the `custom_dataset_type` parameter\n"
            "to support different types of user provided datasets.",
        ),
        BeforeValidator(parse_file),
        Parameter(
            name=(
                "--input-file",  # GenAI-Perf,
            ),
            group=_CLI_GROUP,
        ),
    ] = InputDefaults.FILE

    # Normally inferred from --input-file; can be forced on explicitly.
    fixed_schedule: Annotated[
        bool,
        Field(
            description="Specifies to run a fixed schedule of requests. This is normally inferred from the --input-file parameter, but can be set manually here."
        ),
        Parameter(
            name=(
                "--fixed-schedule",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = InputDefaults.FIXED_SCHEDULE

    # NEW AIPerf Option
    custom_dataset_type: Annotated[
        CustomDatasetType,
        Field(
            description="The type of custom dataset to use.\n"
            "This parameter is used in conjunction with the --file parameter.",
        ),
        Parameter(
            name=("--custom-dataset-type"),
            group=_CLI_GROUP,
        ),
    ] = InputDefaults.CUSTOM_DATASET_TYPE

    # NOTE: `default=None` was removed from the Field below — the default is
    # supplied by the assignment (`= InputDefaults.RANDOM_SEED`) like every
    # other field here, and pydantic v2 rejects a default specified both in
    # `Annotated[..., Field(default=...)]` and by assignment.
    random_seed: Annotated[
        int | None,
        Field(
            description="The seed used to generate random values.\n"
            "Set to some value to make the synthetic data generation deterministic.\n"
            "It will use system default if not provided.",
        ),
        Parameter(
            name=(
                "--random-seed",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = InputDefaults.RANDOM_SEED

    # Nested sub-configs; each contributes its own CLI group of flags.
    audio: AudioConfig = AudioConfig()
    image: ImageConfig = ImageConfig()
    prompt: PromptConfig = PromptConfig()
    conversation: ConversationConfig = ConversationConfig()

validate_fixed_schedule()

Validate the fixed schedule configuration.

Source code in aiperf/common/config/input_config.py
35
36
37
38
39
40
41
42
43
@model_validator(mode="after")
def validate_fixed_schedule(self) -> Self:
    """Validate the fixed schedule configuration.

    A fixed schedule without an input file is an error; supplying a file
    implicitly switches fixed-schedule mode on.
    """
    if self.file is None:
        if self.fixed_schedule:
            raise ValueError("Fixed schedule requires a file to be provided")
    else:
        self.fixed_schedule = True
        logger.debug("Fixed schedule is enabled because file is provided")
    return self

aiperf.common.config.loader

load_service_config()

Load the service configuration.

Source code in aiperf/common/config/loader.py
 7
 8
 9
10
def load_service_config() -> ServiceConfig:
    """Load the service configuration.

    Currently a stub: yields a default-constructed ServiceConfig.
    """
    # TODO: implement
    config = ServiceConfig()
    return config

load_user_config()

Load the user configuration.

Source code in aiperf/common/config/loader.py
13
14
15
16
def load_user_config() -> UserConfig:
    """Load the user configuration.

    Currently a stub that always raises.
    """
    # TODO: implement
    message = "User configuration is not implemented"
    raise NotImplementedError(message)

aiperf.common.config.loadgen_config

LoadGeneratorConfig

Bases: BaseConfig

A configuration class for defining top-level load generator settings.

Source code in aiperf/common/config/loadgen_config.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class LoadGeneratorConfig(BaseConfig):
    """
    A configuration class for defining top-level load generator settings.

    Covers concurrency, request-rate shaping, and request counts for both
    warmup and measurement phases.
    """

    _CLI_GROUP = Groups.LOAD_GENERATOR

    # TODO: Potentially add a validator to ensure that the concurrency is not greater than the request count
    # Number of in-flight requests to maintain (minimum 1).
    concurrency: Annotated[
        int,
        Field(
            ge=1,
            description="The concurrency value to benchmark.",
        ),
        Parameter(
            name=(
                "--concurrency",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = LoadGeneratorDefaults.CONCURRENCY

    # Requests/second; None delegates pacing to the concurrency setting.
    request_rate: Annotated[
        float | None,
        Field(
            gt=0,
            description="Sets the request rate for the load generated by AIPerf. Unit: requests/second",
        ),
        Parameter(
            name=(
                "--request-rate",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = LoadGeneratorDefaults.REQUEST_RATE

    # NEW AIPerf Option
    # Inter-arrival distribution used when request_rate is set.
    request_rate_mode: Annotated[
        RequestRateMode,
        Field(
            description="Sets the request rate mode for the load generated by AIPerf. Valid values: constant, poisson.\n"
            "constant: Generate requests at a fixed rate.\n"
            "poisson: Generate requests using a poisson distribution.\n"
            f"The default is {LoadGeneratorDefaults.REQUEST_RATE_MODE}.",
        ),
        Parameter(
            name=("--request-rate-mode"),
            group=_CLI_GROUP,
        ),
    ] = LoadGeneratorDefaults.REQUEST_RATE_MODE

    # Number of measured requests (minimum 1).
    request_count: Annotated[
        int,
        Field(
            ge=1,
            description="The number of requests to use for measurement.",
        ),
        Parameter(
            name=(
                "--request-count",  # GenAI-Perf
                "--num-requests",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = LoadGeneratorDefaults.REQUEST_COUNT

    # Warmup requests sent before measurement begins; 0 disables warmup.
    warmup_request_count: Annotated[
        int,
        Field(
            ge=0,
            description="The number of warmup requests to send before benchmarking.",
        ),
        Parameter(
            name=(
                "--warmup-request-count",  # GenAI-Perf
                "--num-warmup-requests",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = LoadGeneratorDefaults.WARMUP_REQUEST_COUNT

aiperf.common.config.measurement_config

MeasurementConfig

Bases: BaseConfig

A configuration class for defining top-level measurement settings.

Source code in aiperf/common/config/measurement_config.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class MeasurementConfig(BaseConfig):
    """
    A configuration class for defining top-level measurement settings.

    Both options are declared but CLI parsing is disabled (parse=False)
    until the measurement-window feature is implemented.
    """

    _CLI_GROUP = Groups.MEASUREMENT

    # TODO: Not implemented yet
    # Measurement window, in milliseconds, bounded to [1, 1_000_000].
    measurement_interval: Annotated[
        float,
        Field(
            ge=1,
            le=1_000_000,
            description="The time interval used for each measurement in milliseconds. "
            "AIPerf will sample a time interval specified and take "
            "measurement over the requests completed within that time interval. "
            "When using the default stability percentage, AIPerf will benchmark  "
            "for 3*(measurement_interval) milliseconds.",
        ),
        Parameter(
            name=(
                "--measurement-interval-ms",
                "--measurement-interval",  # GenAI-Perf
                "-p",  # GenAI-Perf
            ),
            parse=False,  # TODO: Not yet supported
            group=_CLI_GROUP,
        ),
    ] = MeasurementDefaults.MEASUREMENT_INTERVAL

    # TODO: Not implemented yet
    # NOTE(review): bounds are (0.0, 1.0) exclusive, so despite the name this
    # is a fraction (e.g. 0.05), not a 0-100 percentage — confirm against
    # MeasurementDefaults.STABILITY_PERCENTAGE and the GenAI-Perf flag.
    stability_percentage: Annotated[
        float,
        Field(
            gt=0.0,
            lt=1.0,
            description="The allowed variation in latency measurements when determining if a result is stable.\n"
            "The measurement is considered as stable if the ratio of max / min\n"
            "from the recent 3 measurements is within (stability percentage)\n"
            "in terms of both infer per second and latency.",
        ),
        Parameter(
            name=(
                "--stability-percentage",  # GenAI-Perf
                "-s",  # GenAI-Perf
            ),
            parse=False,  # TODO: Not yet supported
            group=_CLI_GROUP,
        ),
    ] = MeasurementDefaults.STABILITY_PERCENTAGE

aiperf.common.config.output_config

OutputConfig

Bases: BaseConfig

A configuration class for defining output related settings.

Source code in aiperf/common/config/output_config.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class OutputConfig(BaseConfig):
    """
    A configuration class for defining output related settings.
    """

    _CLI_GROUP = Groups.OUTPUT

    # Destination directory for all generated artifacts.
    artifact_directory: Annotated[
        Path,
        Field(
            description="The directory to store all the (output) artifacts generated by AIPerf.",
        ),
        Parameter(
            name=(
                "--output-artifact-dir",
                "--artifact-dir",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = OutputDefaults.ARTIFACT_DIRECTORY

    # Debug/diagnostic toggle: include internal and hidden metrics in output.
    show_internal_metrics: Annotated[
        bool,
        Field(
            description="Whether to show internal and hidden metrics in the output",
        ),
        Parameter(
            name=("--show-internal-metrics"),
            group=_CLI_GROUP,
        ),
    ] = OutputDefaults.SHOW_INTERNAL_METRICS

aiperf.common.config.prompt_config

InputTokensConfig

Bases: BaseConfig

A configuration class for defining input token related settings.

Source code in aiperf/common/config/prompt_config.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class InputTokensConfig(BaseConfig):
    """
    A configuration class for defining input token related settings.

    Mean/stddev control the input sequence length (ISL) distribution used
    for synthetic prompt generation.
    """

    _CLI_GROUP = Groups.INPUT_SEQUENCE_LENGTH

    # Mean number of prompt tokens for synthetic data (non-negative).
    mean: Annotated[
        int,
        Field(
            ge=0,
            description="The mean of number of tokens in the generated prompts when using synthetic data.",
        ),
        Parameter(
            name=(
                "--prompt-input-tokens-mean",
                "--synthetic-input-tokens-mean",  # GenAI-Perf
                "--isl",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = InputTokensDefaults.MEAN

    # Standard deviation of prompt token count; 0 means fixed-length prompts.
    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of number of tokens in the generated prompts when using synthetic data.",
        ),
        Parameter(
            name=(
                "--prompt-input-tokens-stddev",
                "--synthetic-input-tokens-stddev",  # GenAI-Perf
                "--isl-stddev",
            ),
            group=_CLI_GROUP,
        ),
    ] = InputTokensDefaults.STDDEV

    # NEW AIPerf Option
    # NOTE: `default=512` was removed from the Field below — the default is
    # supplied by the assignment (`= InputTokensDefaults.BLOCK_SIZE`) like
    # every other field here, and pydantic v2 rejects a default specified both
    # in `Annotated[..., Field(default=...)]` and by assignment.
    block_size: Annotated[
        int,
        Field(
            description="The block size of the prompt.",
        ),
        Parameter(
            name=(
                "--prompt-input-tokens-block-size",
                "--synthetic-input-tokens-block-size",
                "--isl-block-size",
            ),
            group=_CLI_GROUP,
        ),
    ] = InputTokensDefaults.BLOCK_SIZE

OutputTokensConfig

Bases: BaseConfig

A configuration class for defining output token related settings.

Source code in aiperf/common/config/prompt_config.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class OutputTokensConfig(BaseConfig):
    """
    A configuration class for defining output token related settings.

    Mean/stddev control the output sequence length (OSL) distribution
    requested from the model.
    """

    _CLI_GROUP = Groups.OUTPUT_SEQUENCE_LENGTH

    # Mean number of tokens requested per output (non-negative).
    mean: Annotated[
        int,
        Field(
            ge=0,
            description="The mean number of tokens in each output.",
        ),
        Parameter(
            name=(
                "--prompt-output-tokens-mean",
                "--output-tokens-mean",  # GenAI-Perf
                "--osl",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = OutputTokensDefaults.MEAN

    # Pins the minimum token count to the requested count for tighter means;
    # per the description, only supported with Triton.
    deterministic: Annotated[
        bool,
        Field(
            description=(
                "This can be set to improve the precision of the mean by setting the\n"
                "minimum number of tokens equal to the requested number of tokens.\n"
                "This is currently supported with Triton."
            ),
        ),
        Parameter(
            name=(
                "--prompt-output-tokens-deterministic",
                "--output-tokens-mean-deterministic",  # GenAI-Perf
                "--osl-deterministic",
            ),
            group=_CLI_GROUP,
        ),
    ] = OutputTokensDefaults.DETERMINISTIC

    # Standard deviation of output token count; 0 means a fixed length.
    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of the number of tokens in each output.",
        ),
        Parameter(
            name=(
                "--prompt-output-tokens-stddev",
                "--output-tokens-stddev",  # GenAI-Perf
                "--osl-stddev",
            ),
            group=_CLI_GROUP,
        ),
    ] = OutputTokensDefaults.STDDEV

PrefixPromptConfig

Bases: BaseConfig

A configuration class for defining prefix prompt related settings.

Source code in aiperf/common/config/prompt_config.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
class PrefixPromptConfig(BaseConfig):
    """
    A configuration class for defining prefix prompt related settings.

    Prefix prompts are prepended to input prompts, which is useful for
    benchmarking K-V cache behavior.
    """

    _CLI_GROUP = Groups.PREFIX_PROMPT

    # Size of the pool prefixes are drawn from; 0 disables prefix prompts.
    pool_size: Annotated[
        int,
        Field(
            ge=0,
            description=(
                "The total size of the prefix prompt pool to select prefixes from.\n"
                "If this value is not zero, these are prompts that are prepended to input prompts.\n"
                "This is useful for benchmarking models that use a K-V cache."
            ),
        ),
        Parameter(
            name=(
                "--prompt-prefix-pool-size",
                "--prefix-prompt-pool-size",
                "--num-prefix-prompts",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = PrefixPromptDefaults.POOL_SIZE

    # Token length of each prefix; only meaningful when pool_size > 0.
    # Fixed stale reference: the description used to mention a nonexistent
    # field "num"; the gating field is `pool_size`.
    length: Annotated[
        int,
        Field(
            ge=0,
            description=(
                "The number of tokens in each prefix prompt.\n"
                'This is only used if "pool_size" is greater than zero.\n'
                "Note that due to the prefix and user prompts being concatenated,\n"
                "the number of tokens in the final prompt may be off by one."
            ),
        ),
        Parameter(
            name=(
                "--prompt-prefix-length",
                "--prefix-prompt-length",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = PrefixPromptDefaults.LENGTH

PromptConfig

Bases: BaseConfig

A configuration class for defining prompt related settings.

Source code in aiperf/common/config/prompt_config.py
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
class PromptConfig(BaseConfig):
    """
    A configuration class for defining prompt related settings.

    Aggregates input-token, output-token, and prefix-prompt sub-configs
    alongside the text batch size.
    """

    _CLI_GROUP = Groups.PROMPT

    # Text requests per batch; per the description, only the embeddings and
    # rankings endpoint types honor this.
    batch_size: Annotated[
        int,
        Field(
            description="The batch size of text requests AIPerf should send.\n"
            "This is currently supported with the embeddings and rankings endpoint types",
        ),
        Parameter(
            name=(
                "--prompt-batch-size",
                "--batch-size-text",  # GenAI-Perf
                "--batch-size",  # GenAI-Perf
                "-b",  # GenAI-Perf
            ),
            group=_CLI_GROUP,
        ),
    ] = PromptDefaults.BATCH_SIZE

    # Nested sub-configs; each contributes its own CLI group of flags.
    input_tokens: InputTokensConfig = InputTokensConfig()
    output_tokens: OutputTokensConfig = OutputTokensConfig()
    prefix_prompt: PrefixPromptConfig = PrefixPromptConfig()

aiperf.common.config.service_config

ServiceConfig

Bases: BaseSettings

Base configuration for all services. It will be provided to all services during their init function.

Source code in aiperf/common/config/service_config.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
class ServiceConfig(BaseSettings):
    """Base configuration for all services. It will be provided to all services during their __init__ function."""

    # Settings may come from the environment (AIPERF_* variables) or a local
    # .env file; unknown keys are kept rather than rejected (extra="allow").
    model_config = SettingsConfigDict(
        env_prefix="AIPERF_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="allow",
    )

    # CLI group under which all of these options are listed in --help output.
    _CLI_GROUP = Groups.SERVICE

    @model_validator(mode="after")
    def validate_log_level_from_verbose_flags(self) -> Self:
        """Set log level based on verbose flags."""
        # extra_verbose takes precedence over verbose; if neither flag is set,
        # the explicitly-configured log_level is left untouched.
        if self.extra_verbose:
            self.log_level = AIPerfLogLevel.TRACE
        elif self.verbose:
            self.log_level = AIPerfLogLevel.DEBUG
        return self

    @model_validator(mode="after")
    def validate_comm_config(self) -> Self:
        """Initialize the comm_config if it is not provided, based on the comm_backend."""
        # Only fills in a default; an explicitly provided comm_config is kept.
        if self.comm_config is None:
            if self.comm_backend == CommunicationBackend.ZMQ_IPC:
                self.comm_config = ZMQIPCConfig()
            elif self.comm_backend == CommunicationBackend.ZMQ_TCP:
                self.comm_config = ZMQTCPConfig()
            else:
                raise ValueError(f"Invalid communication backend: {self.comm_backend}")
        return self

    # How services are launched (e.g. as local processes or on k8s).
    service_run_type: Annotated[
        ServiceRunType,
        Field(
            description="Type of service run (process, k8s)",
        ),
        Parameter(
            name=("--service-run-type", "--run-type"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.SERVICE_RUN_TYPE

    # Selects the ZMQ transport; used by validate_comm_config to build a
    # default comm_config when none is supplied.
    comm_backend: Annotated[
        CommunicationBackend,
        Field(
            description="Communication backend to use",
        ),
        Parameter(
            name=("--comm-backend"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.COMM_BACKEND

    # Full transport configuration; not settable from the CLI (parse=False).
    # When None, it is derived from comm_backend by validate_comm_config.
    comm_config: Annotated[
        BaseZMQCommunicationConfig | None,
        Field(
            description="Communication configuration",
        ),
        Parameter(
            parse=False,  # This is not supported via CLI
        ),
    ] = ServiceDefaults.COMM_CONFIG

    # Seconds of heartbeat silence before a service is declared dead.
    heartbeat_timeout: Annotated[
        float,
        Field(
            description="Time in seconds after which a service is considered dead if no "
            "heartbeat received",
        ),
        Parameter(
            name=("--heartbeat-timeout"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.HEARTBEAT_TIMEOUT

    # Seconds to wait at startup for all required services to register.
    registration_timeout: Annotated[
        float,
        Field(
            description="Time in seconds to wait for all required services to register",
        ),
        Parameter(
            name=("--registration-timeout"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.REGISTRATION_TIMEOUT

    # Default seconds to wait for a response to a command message.
    command_timeout: Annotated[
        float,
        Field(
            description="Default timeout for command responses",
        ),
        Parameter(
            name=("--command-timeout", "--command-timeout-seconds"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.COMMAND_TIMEOUT

    # Seconds between heartbeat messages published by each service.
    heartbeat_interval_seconds: Annotated[
        float,
        Field(
            description="Interval in seconds between heartbeat messages",
        ),
        Parameter(
            name=("--heartbeat-interval-seconds", "--heartbeat-interval"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.HEARTBEAT_INTERVAL_SECONDS

    # Nested worker-pool configuration (min/max workers, health checks).
    # NOTE(review): a mutable default is safe here because pydantic deep-copies
    # model defaults per instance.
    workers: Annotated[
        WorkersConfig,
        Field(
            description="Worker configuration",
        ),
    ] = WorkersConfig()

    # Base logging level; may be overridden to DEBUG/TRACE by the
    # verbose/extra_verbose flags (see validate_log_level_from_verbose_flags).
    log_level: Annotated[
        AIPerfLogLevel,
        Field(
            description="Logging level",
        ),
        Parameter(
            name=("--log-level"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.LOG_LEVEL

    # Shorthand for --log-level DEBUG; excluded from generated templates.
    verbose: Annotated[
        bool,
        Field(
            description="Equivalent to --log-level DEBUG. Enables more verbose logging output, but lacks some raw message logging.",
            json_schema_extra={ADD_TO_TEMPLATE: False},
        ),
        Parameter(
            name=("--verbose", "-v"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.VERBOSE

    # Shorthand for --log-level TRACE; wins over --verbose when both are set.
    extra_verbose: Annotated[
        bool,
        Field(
            description="Equivalent to --log-level TRACE. Enables the most verbose logging output possible.",
            json_schema_extra={ADD_TO_TEMPLATE: False},
        ),
        Parameter(
            name=("--extra-verbose", "-vv"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.EXTRA_VERBOSE

    # Disables the interactive UI; progress is logged to the console instead.
    disable_ui: Annotated[
        bool,
        Field(
            description="Disable the UI (prints progress to the console as log messages). This is equivalent to --ui-type none.",
        ),
        Parameter(
            name=("--disable-ui"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.DISABLE_UI

    # Swap the default asyncio event loop for uvloop.
    enable_uvloop: Annotated[
        bool,
        Field(
            description="Enable the use of uvloop instead of the default asyncio event loop",
        ),
        Parameter(
            name=("--enable-uvloop"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.ENABLE_UVLOOP

    # TODO: Potentially auto-scale this in the future.
    # None means the count is derived from the worker count at runtime.
    record_processor_service_count: Annotated[
        int | None,
        Field(
            ge=1,
            description="Number of services to spawn for processing records. The higher the request rate, the more services "
            "should be spawned in order to keep up with the incoming records. If not specified, the number of services will be "
            "automatically determined based on the worker count.",
        ),
        Parameter(
            name=("--record-processor-service-count", "--record-processors"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.RECORD_PROCESSOR_SERVICE_COUNT

    # Seconds between progress reports shown to the user during a profile run.
    progress_report_interval: Annotated[
        float,
        Field(
            description="Interval in seconds to report progress. This is used to report the progress of the profile to the user.",
        ),
        Parameter(
            name=("--progress-report-interval-seconds", "--progress-report-interval"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.PROGRESS_REPORT_INTERVAL

    # Developer-only: profile AIPerf's own Python code with yappi.
    enable_yappi: Annotated[
        bool,
        Field(
            description="*[Developer use only]* Enable yappi profiling (Yet Another Python Profiler) to profile AIPerf's internal python code. "
            "This can be used in the development of AIPerf in order to find performance bottlenecks across the various services. "
            "The output '.prof' files can be viewed with snakeviz. Requires yappi and snakeviz to be installed. "
            "Run 'pip install yappi snakeviz' to install them.",
        ),
        Parameter(
            name=("--enable-yappi-profiling"),
            group=_CLI_GROUP,
        ),
    ] = ServiceDefaults.ENABLE_YAPPI

    # Per-service DEBUG logging override; parsed from comma-separated values
    # or repeated flags by parse_service_types.
    debug_services: Annotated[
        set[ServiceType] | None,
        Field(
            description="List of services to enable debug logging for. Can be a comma-separated list, a single service type, "
            "or the cli flag can be used multiple times.",
        ),
        Parameter(
            name=("--debug-service", "--debug-services"),
            group=_CLI_GROUP,
        ),
        BeforeValidator(parse_service_types),
    ] = ServiceDefaults.DEBUG_SERVICES

    # Per-service TRACE logging override; same parsing as debug_services.
    trace_services: Annotated[
        set[ServiceType] | None,
        Field(
            description="List of services to enable trace logging for. Can be a comma-separated list, a single service type, "
            "or the cli flag can be used multiple times.",
        ),
        Parameter(
            name=("--trace-service", "--trace-services"),
            group=_CLI_GROUP,
        ),
        BeforeValidator(parse_service_types),
    ] = ServiceDefaults.TRACE_SERVICES

validate_comm_config()

Initialize the comm_config if it is not provided, based on the comm_backend.

Source code in aiperf/common/config/service_config.py
49
50
51
52
53
54
55
56
57
58
59
@model_validator(mode="after")
def validate_comm_config(self) -> Self:
    """Initialize the comm_config if it is not provided, based on the comm_backend."""
    # An explicitly supplied comm_config always wins; nothing to derive.
    if self.comm_config is not None:
        return self
    # Map each supported backend to its default config factory.
    factories = {
        CommunicationBackend.ZMQ_IPC: ZMQIPCConfig,
        CommunicationBackend.ZMQ_TCP: ZMQTCPConfig,
    }
    factory = factories.get(self.comm_backend)
    if factory is None:
        raise ValueError(f"Invalid communication backend: {self.comm_backend}")
    self.comm_config = factory()
    return self

validate_log_level_from_verbose_flags()

Set log level based on verbose flags.

Source code in aiperf/common/config/service_config.py
40
41
42
43
44
45
46
47
@model_validator(mode="after")
def validate_log_level_from_verbose_flags(self) -> Self:
    """Set log level based on verbose flags."""
    # Guard-clause form: extra_verbose takes precedence over verbose; when
    # neither flag is set, the configured log_level is left as-is.
    if self.extra_verbose:
        self.log_level = AIPerfLogLevel.TRACE
        return self
    if self.verbose:
        self.log_level = AIPerfLogLevel.DEBUG
    return self

aiperf.common.config.sweep_config

SweepConfig

Bases: BaseConfig

A sweep of parameters.

Source code in aiperf/common/config/sweep_config.py
 99
100
class SweepConfig(BaseConfig):
    """Configuration describing a sweep over one or more parameters."""

SweepParam

Bases: BaseConfig

A parameter to be swept.

Source code in aiperf/common/config/sweep_config.py
8
9
class SweepParam(BaseConfig):
    """A single parameter to be swept over."""

aiperf.common.config.tokenizer_config

TokenizerConfig

Bases: BaseConfig

A configuration class for defining tokenizer related settings.

Source code in aiperf/common/config/tokenizer_config.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class TokenizerConfig(BaseConfig):
    """
    A configuration class for defining tokenizer related settings.
    """

    # CLI group under which all tokenizer options are listed in --help output.
    _CLI_GROUP = Groups.TOKENIZER

    # HuggingFace tokenizer name or local filepath; None falls back to the
    # model name (per the description below).
    name: Annotated[
        str | None,
        Field(
            description=(
                "The HuggingFace tokenizer to use to interpret token metrics "
                "from prompts and responses.\nThe value can be the "
                "name of a tokenizer or the filepath of the tokenizer.\n"
                "The default value is the model name."
            ),
        ),
        Parameter(
            name=("--tokenizer"),
            group=_CLI_GROUP,
        ),
    ] = TokenizerDefaults.NAME

    # Branch name, tag name, or commit ID of the tokenizer to fetch.
    revision: Annotated[
        str,
        Field(
            description=(
                "The specific model version to use.\n"
                "It can be a branch name, tag name, or commit ID."
            ),
        ),
        Parameter(
            name=("--tokenizer-revision"),
            group=_CLI_GROUP,
        ),
    ] = TokenizerDefaults.REVISION

    # Security-sensitive: permits downloading and executing custom tokenizer
    # code from HuggingFace Hub; only enable for trusted repositories.
    trust_remote_code: Annotated[
        bool,
        Field(
            description=(
                "Allows custom tokenizer to be downloaded and executed.\n"
                "This carries security risks and should only be used for repositories you trust.\n"
                "This is only necessary for custom tokenizers stored in HuggingFace Hub."
            ),
        ),
        Parameter(
            name=("--tokenizer-trust-remote-code"),
            group=_CLI_GROUP,
        ),
    ] = TokenizerDefaults.TRUST_REMOTE_CODE

aiperf.common.config.user_config

UserConfig

Bases: BaseConfig

A configuration class for defining top-level user settings.

Source code in aiperf/common/config/user_config.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class UserConfig(BaseConfig):
    """
    A configuration class for defining top-level user settings.

    Aggregates the per-area configuration sections (endpoint, input, output,
    tokenizer, load generator, measurement) into a single model. Only
    ``endpoint`` is required; every other section has a usable default.

    NOTE(review): the mutable defaults below are safe because pydantic
    deep-copies model defaults per instance.
    """

    # Required: where and how to reach the system under test.
    endpoint: Annotated[
        EndpointConfig,
        Field(
            description="Endpoint configuration",
        ),
    ]

    # Input/dataset generation settings.
    input: Annotated[
        InputConfig,
        Field(
            description="Input configuration",
        ),
    ] = InputConfig()

    # Output/artifact settings.
    output: Annotated[
        OutputConfig,
        Field(
            description="Output configuration",
        ),
    ] = OutputConfig()

    # Tokenizer used to compute token metrics.
    tokenizer: Annotated[
        TokenizerConfig,
        Field(
            description="Tokenizer configuration",
        ),
    ] = TokenizerConfig()

    # Load generation (request rate / concurrency) settings.
    loadgen: Annotated[
        LoadGeneratorConfig,
        Field(
            description="Load Generator configuration",
        ),
    ] = LoadGeneratorConfig()

    # Measurement/metric collection settings.
    measurement: Annotated[
        MeasurementConfig,
        Field(
            description="Measurement configuration",
        ),
    ] = MeasurementConfig()

aiperf.common.config.worker_config

WorkersConfig

Bases: BaseConfig

Worker configuration.

Source code in aiperf/common/config/worker_config.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class WorkersConfig(BaseConfig):
    """Worker configuration.

    Controls how many workers AIPerf maintains (``min``/``max`` bounds) and
    how frequently they publish health status.
    """

    # CLI group under which all worker options are listed in --help output.
    _CLI_GROUP = Groups.WORKERS

    # Lower bound on the worker pool size; None leaves it unbounded below.
    min: Annotated[
        int | None,
        Field(
            description="Minimum number of workers to maintain",
        ),
        Parameter(
            name=("--workers-min", "--min-workers"),
            group=_CLI_GROUP,
        ),
    ] = WorkersDefaults.MIN

    # Upper bound on the worker pool size; None derives it from concurrency
    # and CPU count as described below.
    max: Annotated[
        int | None,
        Field(
            description="Maximum number of workers to create. If not specified, the number of"
            " workers will be determined by the smaller of (concurrency + 1) and (num CPUs - 1).",
        ),
        Parameter(
            name=("--workers-max", "--max-workers"),
            group=_CLI_GROUP,
        ),
    ] = WorkersDefaults.MAX

    # Seconds between worker health-status publications.
    health_check_interval: Annotated[
        float,
        Field(
            # Fixed garbled user-facing help text ("to for workers" -> "for workers").
            description="Interval in seconds for workers to publish their health status.",
        ),
        Parameter(
            name=("--workers-health-check-interval"),
            group=_CLI_GROUP,
        ),
    ] = WorkersDefaults.HEALTH_CHECK_INTERVAL

aiperf.common.config.zmq_config

BaseZMQCommunicationConfig

Bases: BaseModel, ABC

Configuration for ZMQ communication.

Source code in aiperf/common/config/zmq_config.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class BaseZMQCommunicationConfig(BaseModel, ABC):
    """Configuration for ZMQ communication.

    Abstract base for the concrete IPC and TCP transport configurations.
    Subclasses must supply the three proxy configs (class-level) and the
    three address properties below; `get_address` then maps a logical
    `CommAddress` to the concrete transport address.
    """

    # Proxy config options to be overridden by subclasses
    event_bus_proxy_config: ClassVar[BaseZMQProxyConfig]
    dataset_manager_proxy_config: ClassVar[BaseZMQProxyConfig]
    raw_inference_proxy_config: ClassVar[BaseZMQProxyConfig]

    @property
    @abstractmethod
    def records_push_pull_address(self) -> str:
        """Get the records push/pull address based on protocol configuration."""

    @property
    @abstractmethod
    def credit_drop_address(self) -> str:
        """Get the credit drop address based on protocol configuration."""

    @property
    @abstractmethod
    def credit_return_address(self) -> str:
        """Get the credit return address based on protocol configuration."""

    def get_address(self, address_type: CommAddress) -> str:
        """Get the actual address based on the address type.

        Raises:
            ValueError: If `address_type` is not a recognized `CommAddress`.
        """
        # Maps every logical address to its concrete value; note that all
        # entries are evaluated eagerly when the dict is built.
        address_map = {
            CommAddress.EVENT_BUS_PROXY_FRONTEND: self.event_bus_proxy_config.frontend_address,
            CommAddress.EVENT_BUS_PROXY_BACKEND: self.event_bus_proxy_config.backend_address,
            CommAddress.DATASET_MANAGER_PROXY_FRONTEND: self.dataset_manager_proxy_config.frontend_address,
            CommAddress.DATASET_MANAGER_PROXY_BACKEND: self.dataset_manager_proxy_config.backend_address,
            CommAddress.CREDIT_DROP: self.credit_drop_address,
            CommAddress.CREDIT_RETURN: self.credit_return_address,
            CommAddress.RECORDS: self.records_push_pull_address,
            CommAddress.RAW_INFERENCE_PROXY_FRONTEND: self.raw_inference_proxy_config.frontend_address,
            CommAddress.RAW_INFERENCE_PROXY_BACKEND: self.raw_inference_proxy_config.backend_address,
        }

        if address_type not in address_map:
            raise ValueError(f"Invalid address type: {address_type}")

        return address_map[address_type]

credit_drop_address abstractmethod property

Get the credit drop address based on protocol configuration.

credit_return_address abstractmethod property

Get the credit return address based on protocol configuration.

records_push_pull_address abstractmethod property

Get the records push/pull address based on protocol configuration.

get_address(address_type)

Get the actual address based on the address type.

Source code in aiperf/common/config/zmq_config.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def get_address(self, address_type: CommAddress) -> str:
    """Get the actual address based on the address type."""
    # Hoist the proxy configs into locals to keep the mapping readable.
    event_bus = self.event_bus_proxy_config
    dataset_manager = self.dataset_manager_proxy_config
    raw_inference = self.raw_inference_proxy_config
    address_map = {
        CommAddress.EVENT_BUS_PROXY_FRONTEND: event_bus.frontend_address,
        CommAddress.EVENT_BUS_PROXY_BACKEND: event_bus.backend_address,
        CommAddress.DATASET_MANAGER_PROXY_FRONTEND: dataset_manager.frontend_address,
        CommAddress.DATASET_MANAGER_PROXY_BACKEND: dataset_manager.backend_address,
        CommAddress.CREDIT_DROP: self.credit_drop_address,
        CommAddress.CREDIT_RETURN: self.credit_return_address,
        CommAddress.RECORDS: self.records_push_pull_address,
        CommAddress.RAW_INFERENCE_PROXY_FRONTEND: raw_inference.frontend_address,
        CommAddress.RAW_INFERENCE_PROXY_BACKEND: raw_inference.backend_address,
    }
    if address_type in address_map:
        return address_map[address_type]
    raise ValueError(f"Invalid address type: {address_type}")

BaseZMQProxyConfig

Bases: BaseModel, ABC

Configuration Protocol for ZMQ Proxy.

Source code in aiperf/common/config/zmq_config.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class BaseZMQProxyConfig(BaseModel, ABC):
    """Configuration Protocol for ZMQ Proxy.

    Abstract interface implemented by the IPC and TCP proxy configurations.
    Frontend/backend addresses are always available; control and capture
    addresses are optional and return None when not enabled/configured.
    """

    @property
    @abstractmethod
    def frontend_address(self) -> str:
        """Get the frontend address based on protocol configuration."""

    @property
    @abstractmethod
    def backend_address(self) -> str:
        """Get the backend address based on protocol configuration."""

    @property
    @abstractmethod
    def control_address(self) -> str | None:
        """Get the control address based on protocol configuration."""

    @property
    @abstractmethod
    def capture_address(self) -> str | None:
        """Get the capture address based on protocol configuration."""

backend_address abstractmethod property

Get the backend address based on protocol configuration.

capture_address abstractmethod property

Get the capture address based on protocol configuration.

control_address abstractmethod property

Get the control address based on protocol configuration.

frontend_address abstractmethod property

Get the frontend address based on protocol configuration.

ZMQIPCConfig

Bases: BaseZMQCommunicationConfig

Configuration for IPC transport.

Source code in aiperf/common/config/zmq_config.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
class ZMQIPCConfig(BaseZMQCommunicationConfig):
    """Configuration for IPC transport.

    All sockets live as files under `path`; each logical channel gets its
    own `.ipc` socket file.
    """

    # Directory under which all IPC socket files are created.
    path: str = Field(default="/tmp/aiperf", description="Path for IPC sockets")
    dataset_manager_proxy_config: ZMQIPCProxyConfig = Field(  # type: ignore
        default=ZMQIPCProxyConfig(name="dataset_manager_proxy"),
        description="Configuration for the ZMQ Dealer Router Proxy. If provided, the proxy will be created and started.",
    )
    event_bus_proxy_config: ZMQIPCProxyConfig = Field(  # type: ignore
        default=ZMQIPCProxyConfig(name="event_bus_proxy"),
        description="Configuration for the ZMQ XPUB/XSUB Proxy. If provided, the proxy will be created and started.",
    )
    raw_inference_proxy_config: ZMQIPCProxyConfig = Field(  # type: ignore
        default=ZMQIPCProxyConfig(name="raw_inference_proxy"),
        description="Configuration for the ZMQ Push/Pull Proxy. If provided, the proxy will be created and started.",
    )

    @property
    def records_push_pull_address(self) -> str:
        """Get the records push/pull address based on protocol configuration."""
        return f"ipc://{self.path}/records_push_pull.ipc"

    @property
    def credit_drop_address(self) -> str:
        """Get the credit drop address based on protocol configuration."""
        return f"ipc://{self.path}/credit_drop.ipc"

    @property
    def credit_return_address(self) -> str:
        """Get the credit return address based on protocol configuration."""
        return f"ipc://{self.path}/credit_return.ipc"

credit_drop_address property

Get the credit drop address based on protocol configuration.

credit_return_address property

Get the credit return address based on protocol configuration.

records_push_pull_address property

Get the records push/pull address based on protocol configuration.

ZMQIPCProxyConfig

Bases: BaseZMQProxyConfig

Configuration for IPC proxy.

Source code in aiperf/common/config/zmq_config.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class ZMQIPCProxyConfig(BaseZMQProxyConfig):
    """Configuration for IPC proxy."""

    path: str = Field(default="/tmp/aiperf", description="Path for IPC sockets")
    name: str = Field(default="proxy", description="Name for IPC sockets")
    enable_control: bool = Field(default=False, description="Enable control socket")
    enable_capture: bool = Field(default=False, description="Enable capture socket")

    @property
    def frontend_address(self) -> str:
        """IPC frontend socket address for this proxy."""
        return f"ipc://{self.path}/{self.name}_frontend.ipc"

    @property
    def backend_address(self) -> str:
        """IPC backend socket address for this proxy."""
        return f"ipc://{self.path}/{self.name}_backend.ipc"

    @property
    def control_address(self) -> str | None:
        """IPC control socket address, or None when the control socket is disabled."""
        if not self.enable_control:
            return None
        return f"ipc://{self.path}/{self.name}_control.ipc"

    @property
    def capture_address(self) -> str | None:
        """IPC capture socket address, or None when the capture socket is disabled."""
        if not self.enable_capture:
            return None
        return f"ipc://{self.path}/{self.name}_capture.ipc"

backend_address property

Get the backend address based on protocol configuration.

capture_address property

Get the capture address based on protocol configuration.

control_address property

Get the control address based on protocol configuration.

frontend_address property

Get the frontend address based on protocol configuration.

ZMQTCPConfig

Bases: BaseZMQCommunicationConfig

Configuration for TCP transport.

Source code in aiperf/common/config/zmq_config.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
class ZMQTCPConfig(BaseZMQCommunicationConfig):
    """Configuration for TCP transport.

    Each logical channel binds to `host` on its own dedicated port; the
    three proxies get fixed frontend/backend port pairs by default.
    """

    # Bind address for all TCP sockets (0.0.0.0 = all interfaces).
    host: str = Field(
        default="0.0.0.0",
        description="Host address for TCP connections",
    )
    records_push_pull_port: int = Field(
        default=5557, description="Port for inference push/pull messages"
    )
    credit_drop_port: int = Field(
        default=5562, description="Port for credit drop operations"
    )
    credit_return_port: int = Field(
        default=5563, description="Port for credit return operations"
    )
    dataset_manager_proxy_config: ZMQTCPProxyConfig = Field(  # type: ignore
        default=ZMQTCPProxyConfig(
            frontend_port=5661,
            backend_port=5662,
        ),
        description="Configuration for the ZMQ Proxy. If provided, the proxy will be created and started.",
    )
    event_bus_proxy_config: ZMQTCPProxyConfig = Field(  # type: ignore
        default=ZMQTCPProxyConfig(
            frontend_port=5663,
            backend_port=5664,
        ),
        description="Configuration for the ZMQ Proxy. If provided, the proxy will be created and started.",
    )
    raw_inference_proxy_config: ZMQTCPProxyConfig = Field(  # type: ignore
        default=ZMQTCPProxyConfig(
            frontend_port=5665,
            backend_port=5666,
        ),
        description="Configuration for the ZMQ Proxy. If provided, the proxy will be created and started.",
    )

    @property
    def records_push_pull_address(self) -> str:
        """Get the records push/pull address based on protocol configuration."""
        return f"tcp://{self.host}:{self.records_push_pull_port}"

    @property
    def credit_drop_address(self) -> str:
        """Get the credit drop address based on protocol configuration."""
        return f"tcp://{self.host}:{self.credit_drop_port}"

    @property
    def credit_return_address(self) -> str:
        """Get the credit return address based on protocol configuration."""
        return f"tcp://{self.host}:{self.credit_return_port}"

credit_drop_address property

Get the credit drop address based on protocol configuration.

credit_return_address property

Get the credit return address based on protocol configuration.

records_push_pull_address property

Get the records push/pull address based on protocol configuration.

ZMQTCPProxyConfig

Bases: BaseZMQProxyConfig

Configuration for TCP proxy.

Source code in aiperf/common/config/zmq_config.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class ZMQTCPProxyConfig(BaseZMQProxyConfig):
    """Configuration for TCP proxy."""

    host: str = Field(
        default="0.0.0.0",
        description="Host address for TCP connections",
    )
    frontend_port: int = Field(
        default=15555, description="Port for frontend address for proxy"
    )
    backend_port: int = Field(
        default=15556, description="Port for backend address for proxy"
    )
    control_port: int | None = Field(
        default=None, description="Port for control address for proxy"
    )
    capture_port: int | None = Field(
        default=None, description="Port for capture address for proxy"
    )

    @property
    def frontend_address(self) -> str:
        """TCP frontend address for this proxy."""
        return f"tcp://{self.host}:{self.frontend_port}"

    @property
    def backend_address(self) -> str:
        """TCP backend address for this proxy."""
        return f"tcp://{self.host}:{self.backend_port}"

    @property
    def control_address(self) -> str | None:
        """TCP control address, or None when no control port is configured."""
        # NOTE: falsy check, so port 0 also yields None (same as the original).
        if not self.control_port:
            return None
        return f"tcp://{self.host}:{self.control_port}"

    @property
    def capture_address(self) -> str | None:
        """TCP capture address, or None when no capture port is configured."""
        if not self.capture_port:
            return None
        return f"tcp://{self.host}:{self.capture_port}"

backend_address property

Get the backend address based on protocol configuration.

capture_address property

Get the capture address based on protocol configuration.

control_address property

Get the control address based on protocol configuration.

frontend_address property

Get the frontend address based on protocol configuration.

aiperf.common.constants

DEFAULT_COMMAND_RESPONSE_TIMEOUT = 30.0 module-attribute

Default timeout for command responses in seconds.

DEFAULT_COMMS_REQUEST_TIMEOUT = 30.0 module-attribute

Default timeout for requests from req_clients to rep_clients in seconds.

DEFAULT_CONNECTION_PROBE_INTERVAL = 0.1 module-attribute

Default interval for connection probes in seconds until a response is received.

DEFAULT_CONNECTION_PROBE_TIMEOUT = 30.0 module-attribute

Maximum amount of time to wait for connection probe response.

DEFAULT_MAX_REGISTRATION_ATTEMPTS = 10 module-attribute

Default maximum number of registration attempts for component services before giving up.

DEFAULT_PROFILE_CANCEL_TIMEOUT = 10.0 module-attribute

Default timeout for cancelling a profile run in seconds.

DEFAULT_PROFILE_CONFIGURE_TIMEOUT = 300.0 module-attribute

Default timeout for profile configure command in seconds.

DEFAULT_PROFILE_START_TIMEOUT = 60.0 module-attribute

Default timeout for profile start command in seconds.

DEFAULT_PULL_CLIENT_MAX_CONCURRENCY = 100000 module-attribute

Default maximum concurrency for pull clients.

DEFAULT_REGISTRATION_INTERVAL = 1.0 module-attribute

Default interval between registration attempts in seconds for component services.

DEFAULT_SERVICE_REGISTRATION_TIMEOUT = 30.0 module-attribute

Default timeout for service registration in seconds.

DEFAULT_SERVICE_START_TIMEOUT = 30.0 module-attribute

Default timeout for service start in seconds.

DEFAULT_SHUTDOWN_ACK_TIMEOUT = 5.0 module-attribute

Default timeout for waiting for a shutdown command response in seconds.

GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 5.0 module-attribute

Default timeout for shutting down services in seconds.

TASK_CANCEL_TIMEOUT_LONG = 5.0 module-attribute

Maximum time to wait for complex tasks to complete when cancelling them (like parent tasks).

TASK_CANCEL_TIMEOUT_SHORT = 2.0 module-attribute

Maximum time to wait for simple tasks to complete when cancelling them.

aiperf.common.decorators

Decorators for AIPerf components. Note that these are not the same as hooks. Hooks are used to specify that a function should be called at a specific time, while decorators are used to specify that a class or function should be treated in a specific way.

see also: :mod:aiperf.common.hooks for hook decorators.

DecoratorAttrs

Constant attribute names for decorators.

When you decorate a class with a decorator, the decorator type and parameters are set as attributes on the class.

Source code in aiperf/common/decorators.py
17
18
19
20
21
22
23
24
class DecoratorAttrs:
    """Well-known attribute names used by the decorators in this module.

    Applying a decorator to a class records the decorator type and its
    parameters as attributes on that class; the constants here name
    those attributes.
    """

    IMPLEMENTS_PROTOCOL = "__implements_protocol__"

implements_protocol(protocol)

Decorator to specify that the class implements the given protocol.

Example:

@implements_protocol(ServiceProtocol)
class BaseService:
    pass

The above is the equivalent to setting:

BaseService.__implements_protocol__ = ServiceProtocol
Source code in aiperf/common/decorators.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def implements_protocol(protocol: type[ProtocolT]) -> Callable:
    """Decorator to specify that the class implements the given protocol.

    The protocol is recorded on the decorated class under the
    ``DecoratorAttrs.IMPLEMENTS_PROTOCOL`` attribute; the class itself is
    returned unmodified otherwise.

    Example:
    ```python
    @implements_protocol(ServiceProtocol)
    class BaseService:
        pass
    ```

    The above is the equivalent to setting:
    ```python
    BaseService.__implements_protocol__ = ServiceProtocol
    ```
    """

    def decorator(cls: type[ClassProtocolT]) -> type[ClassProtocolT]:
        # NOTE(review): `TYPE_CHECKING` is always False at runtime, and static
        # type checkers do not *execute* code, so the warn/raise validation
        # below never actually runs anywhere — it is effectively dead code.
        # Presumably the intent was runtime validation of the decorated class;
        # confirm, and either drop the guard or delete the dead branch. Also
        # note that `issubclass` raises TypeError for runtime protocols with
        # non-method members, so enabling these checks needs care.
        if TYPE_CHECKING:
            if not hasattr(protocol, "_is_runtime_protocol"):
                warn(
                    f"Protocol {protocol.__name__} is not a runtime protocol. "
                    "Please use the @runtime_checkable decorator to mark it as a runtime protocol.",
                    category=UserWarning,
                    stacklevel=2,
                )
                raise TypeError(
                    f"Protocol {protocol.__name__} is not a runtime protocol. "
                    "Please use the @runtime_checkable decorator to mark it as a runtime protocol."
                )
            if not issubclass(cls, protocol):
                warn(
                    f"Class {cls.__name__} does not implement the {protocol.__name__} protocol.",
                    category=UserWarning,
                    stacklevel=2,
                )
                raise TypeError(
                    f"Class {cls.__name__} does not implement the {protocol.__name__} protocol."
                )
        # Record the protocol under the well-known attribute name so that other
        # tooling can discover it (see DecoratorAttrs.IMPLEMENTS_PROTOCOL).
        setattr(cls, DecoratorAttrs.IMPLEMENTS_PROTOCOL, protocol)
        return cls

    return decorator

aiperf.common.enums.base_enums

BasePydanticBackedStrEnum

Bases: CaseInsensitiveStrEnum

Custom enumeration class that extends CaseInsensitiveStrEnum and is backed by a BasePydanticEnumInfo that contains the tag, and any other information that is needed to represent the enum member.

Source code in aiperf/common/enums/base_enums.py
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class BasePydanticBackedStrEnum(CaseInsensitiveStrEnum):
    """
    Custom enumeration class that extends `CaseInsensitiveStrEnum`
    and is backed by a `BasePydanticEnumInfo` that contains the `tag`, and any other information that is needed
    to represent the enum member.
    """

    # Override the __new__ method to store the `BasePydanticEnumInfo` subclass model as an attribute. This is a python feature that
    # allows us to modify the behavior of the enum class's constructor. We use this to ensure that the enums still look like
    # a regular string enum, but also have the additional information stored as an attribute.
    def __new__(cls, info: BasePydanticEnumInfo) -> Self:
        # Create a new string object based on this class and the tag value.
        obj = str.__new__(cls, info.tag)
        # Ensure string value is set for comparison. This is how enums work internally.
        obj._value_ = info.tag
        # Store the Pydantic model as an attribute.
        obj._info: BasePydanticEnumInfo = info  # type: ignore
        return obj

    @cached_property
    def info(self) -> BasePydanticEnumInfo:
        """Get the enum info for the enum member."""
        # This is the Pydantic model that was stored as an attribute in the __new__ method.
        # cached_property caches the result on the member, so this lookup only
        # happens once per enum member.
        return self._info  # type: ignore

info cached property

Get the enum info for the enum member.

BasePydanticEnumInfo

Bases: BaseModel

Base class for all enum info classes that extend BasePydanticBackedStrEnum. By default, it provides a tag for the enum member, which is used for lookup and string comparison, and the subclass can provide additional information as needed.

Source code in aiperf/common/enums/base_enums.py
52
53
54
55
56
57
58
59
60
61
62
63
64
class BasePydanticEnumInfo(BaseModel):
    """Base class for all enum info classes that extend `BasePydanticBackedStrEnum`. By default, it
    provides a `tag` for the enum member, which is used for lookup and string comparison,
    and the subclass can provide additional information as needed."""

    # The tag doubles as the enum member's string value (see
    # BasePydanticBackedStrEnum.__new__), so it must be non-empty.
    tag: str = Field(
        ...,
        min_length=1,
        description="The string value of the enum member used for lookup, serialization, and string insensitive comparison.",
    )

    def __str__(self) -> str:
        # Render as the tag so the info prints like the enum member itself.
        return self.tag

CaseInsensitiveStrEnum

Bases: str, Enum

CaseInsensitiveStrEnum is a custom enumeration class that extends str and Enum to provide case-insensitive lookup functionality for its members.

Source code in aiperf/common/enums/base_enums.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class CaseInsensitiveStrEnum(str, Enum):
    """A string-valued enumeration whose members compare, hash, and look up
    case-insensitively against both plain strings and other enum members.
    """

    def __str__(self) -> str:
        return self.value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __eq__(self, other: object) -> bool:
        # Fold our own value once, then compare case-insensitively against
        # plain strings or other enums; anything else falls back to Enum.
        folded = self.value.lower()
        if isinstance(other, str):
            return folded == other.lower()
        if isinstance(other, Enum):
            return folded == other.value.lower()
        return super().__eq__(other)

    def __hash__(self) -> int:
        # Must stay consistent with the case-insensitive __eq__ above.
        return hash(self.value.lower())

    @classmethod
    def _missing_(cls, value):
        """Fallback lookup invoked when `value` matched no member directly.

        Returns:
            The member whose value matches `value` case-insensitively when
            `value` is a string; otherwise None.
        """
        if not isinstance(value, str):
            return None
        needle = value.lower()
        return next(
            (member for member in cls if member.value.lower() == needle),
            None,
        )

aiperf.common.enums.benchmark_suite_enums

BenchmarkSuiteCompletionTrigger

Bases: CaseInsensitiveStrEnum

Determines how the suite completion is determined in order to know how to track the progress.

Source code in aiperf/common/enums/benchmark_suite_enums.py
 7
 8
 9
10
11
class BenchmarkSuiteCompletionTrigger(CaseInsensitiveStrEnum):
    """Determines how the suite completion is determined in order to know how to track the progress."""

    # Currently the only trigger; add new trigger strategies as members here.
    COMPLETED_PROFILES = "completed_profiles"
    """The suite will run until all profiles are completed."""

COMPLETED_PROFILES = 'completed_profiles' class-attribute instance-attribute

The suite will run until all profiles are completed.

BenchmarkSuiteType

Bases: CaseInsensitiveStrEnum

Determines the type of suite to know how to track the progress.

Source code in aiperf/common/enums/benchmark_suite_enums.py
19
20
21
22
23
class BenchmarkSuiteType(CaseInsensitiveStrEnum):
    """Determines the type of suite to know how to track the progress."""

    # Currently the only suite type; add new suite shapes as members here.
    SINGLE_PROFILE = "single_profile"
    """A suite with a single profile run."""

SINGLE_PROFILE = 'single_profile' class-attribute instance-attribute

A suite with a single profile run.

aiperf.common.enums.command_enums

aiperf.common.enums.communication_enums

CommAddress

Bases: CaseInsensitiveStrEnum

Enum for specifying the address type for communication clients. This is used to lookup the address in the communication config.

Source code in aiperf/common/enums/communication_enums.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class CommAddress(CaseInsensitiveStrEnum):
    """Enum for specifying the address type for communication clients.
    This is used to lookup the address in the communication config."""

    # Proxy addresses come in FRONTEND/BACKEND pairs: producers talk to the
    # frontend while consumers talk to the matching backend.
    EVENT_BUS_PROXY_FRONTEND = "event_bus_proxy_frontend"
    """Frontend address for services to publish messages to."""

    EVENT_BUS_PROXY_BACKEND = "event_bus_proxy_backend"
    """Backend address for services to subscribe to messages."""

    # Credit-flow addresses between the TimingManager and Workers.
    CREDIT_DROP = "credit_drop"
    """Address to send CreditDrop messages from the TimingManager to the Worker."""

    CREDIT_RETURN = "credit_return"
    """Address to send CreditReturn messages from the Worker to the TimingManager."""

    RECORDS = "records"
    """Address to send parsed records from InferenceParser to RecordManager."""

    DATASET_MANAGER_PROXY_FRONTEND = "dataset_manager_proxy_frontend"
    """Frontend address for sending requests to the DatasetManager."""

    DATASET_MANAGER_PROXY_BACKEND = "dataset_manager_proxy_backend"
    """Backend address for the DatasetManager to receive requests from clients."""

    RAW_INFERENCE_PROXY_FRONTEND = "raw_inference_proxy_frontend"
    """Frontend address for sending raw inference messages to the InferenceParser from Workers."""

    RAW_INFERENCE_PROXY_BACKEND = "raw_inference_proxy_backend"
    """Backend address for the InferenceParser to receive raw inference messages from Workers."""
CREDIT_DROP = 'credit_drop' class-attribute instance-attribute

Address to send CreditDrop messages from the TimingManager to the Worker.

CREDIT_RETURN = 'credit_return' class-attribute instance-attribute

Address to send CreditReturn messages from the Worker to the TimingManager.

DATASET_MANAGER_PROXY_BACKEND = 'dataset_manager_proxy_backend' class-attribute instance-attribute

Backend address for the DatasetManager to receive requests from clients.

DATASET_MANAGER_PROXY_FRONTEND = 'dataset_manager_proxy_frontend' class-attribute instance-attribute

Frontend address for sending requests to the DatasetManager.

EVENT_BUS_PROXY_BACKEND = 'event_bus_proxy_backend' class-attribute instance-attribute

Backend address for services to subscribe to messages.

EVENT_BUS_PROXY_FRONTEND = 'event_bus_proxy_frontend' class-attribute instance-attribute

Frontend address for services to publish messages to.

RAW_INFERENCE_PROXY_BACKEND = 'raw_inference_proxy_backend' class-attribute instance-attribute

Backend address for the InferenceParser to receive raw inference messages from Workers.

RAW_INFERENCE_PROXY_FRONTEND = 'raw_inference_proxy_frontend' class-attribute instance-attribute

Frontend address for sending raw inference messages to the InferenceParser from Workers.

RECORDS = 'records' class-attribute instance-attribute

Address to send parsed records from InferenceParser to RecordManager.

aiperf.common.enums.data_exporter_enums

aiperf.common.enums.dataset_enums

aiperf.common.enums.endpoints_enums

EndpointType

Bases: BasePydanticBackedStrEnum

Endpoint types supported by AIPerf.

These are the full definitions of the endpoints that are supported by AIPerf. Each enum value contains additional metadata about the endpoint, such as whether it supports streaming, produces tokens, and the default endpoint path. This is stored as an attribute on the enum value, and can be accessed via the info property. The enum values can still be used as strings for user input and comparison (via the tag field).

Source code in aiperf/common/enums/endpoints_enums.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
class EndpointType(BasePydanticBackedStrEnum):
    """Endpoint types supported by AIPerf.

    These are the full definitions of the endpoints that are supported by AIPerf.
    Each enum value contains additional metadata about the endpoint, such as whether it supports streaming,
    produces tokens, and the default endpoint path. This is stored as an attribute on the enum value, and can be accessed
    via the `info` property. The enum values can still be used as strings for user input and comparison (via the `tag` field).
    """

    # Each member is constructed from an EndpointTypeInfo; its `tag` becomes
    # the member's string value (see BasePydanticBackedStrEnum.__new__).
    OPENAI_CHAT_COMPLETIONS = EndpointTypeInfo(
        tag="chat",
        supports_streaming=True,
        produces_tokens=True,
        supports_audio=True,
        supports_images=True,
        endpoint_path="/v1/chat/completions",
        metrics_title="LLM Metrics",
    )
    OPENAI_COMPLETIONS = EndpointTypeInfo(
        tag="completions",
        supports_streaming=True,
        produces_tokens=True,
        endpoint_path="/v1/completions",
        metrics_title="LLM Metrics",
    )
    OPENAI_EMBEDDINGS = EndpointTypeInfo(
        tag="embeddings",
        supports_streaming=False,
        produces_tokens=False,
        endpoint_path="/v1/embeddings",
        metrics_title="Embeddings Metrics",
    )
    OPENAI_RESPONSES = EndpointTypeInfo(
        tag="responses",
        supports_streaming=True,
        produces_tokens=True,
        supports_audio=False,  # Not yet supported by OpenAI
        supports_images=True,
        endpoint_path="/v1/responses",
        metrics_title="LLM Metrics",
    )

    @cached_property
    def info(self) -> EndpointTypeInfo:
        """Get the endpoint info for the endpoint type."""
        # Narrows the return type of the base class's `info` to EndpointTypeInfo.
        return self._info  # type: ignore

    # The properties below are convenience passthroughs to the underlying
    # EndpointTypeInfo fields.

    @property
    def supports_streaming(self) -> bool:
        """Return True if the endpoint supports streaming. This is used for validation of user input."""
        return self.info.supports_streaming

    @property
    def produces_tokens(self) -> bool:
        """Return True if the endpoint produces tokens. This is used to determine what metrics are applicable to the endpoint."""
        return self.info.produces_tokens

    @property
    def endpoint_path(self) -> str | None:
        """Get the default endpoint path for the endpoint type. If None, the endpoint does not have a specific path."""
        return self.info.endpoint_path

    @property
    def supports_audio(self) -> bool:
        """Return True if the endpoint supports audio input.
        This is used to determine what metrics are applicable to the endpoint, as well as what inputs can be used."""
        return self.info.supports_audio

    @property
    def supports_images(self) -> bool:
        """Return True if the endpoint supports image input.
        This is used to determine what metrics are applicable to the endpoint, as well as what inputs can be used."""
        return self.info.supports_images

    @property
    def metrics_title(self) -> str:
        """Get the metrics table title string for the endpoint type. If None, the default title is used."""
        # Fall back to a generic title when the info does not define one.
        return self.info.metrics_title or "Metrics"

endpoint_path property

Get the default endpoint path for the endpoint type. If None, the endpoint does not have a specific path.

info cached property

Get the endpoint info for the endpoint type.

metrics_title property

Get the metrics table title string for the endpoint type. If None, the default title is used.

produces_tokens property

Return True if the endpoint produces tokens. This is used to determine what metrics are applicable to the endpoint.

supports_audio property

Return True if the endpoint supports audio input. This is used to determine what metrics are applicable to the endpoint, as well as what inputs can be used.

supports_images property

Return True if the endpoint supports image input. This is used to determine what metrics are applicable to the endpoint, as well as what inputs can be used.

supports_streaming property

Return True if the endpoint supports streaming. This is used for validation of user input.

EndpointTypeInfo

Bases: BasePydanticEnumInfo

Pydantic model for endpoint-specific metadata. This model is used to store additional info on each EndpointType enum value.

For documentation on the fields, see the :class:EndpointType @property functions.

Source code in aiperf/common/enums/endpoints_enums.py
14
15
16
17
18
19
20
21
22
23
24
25
class EndpointTypeInfo(BasePydanticEnumInfo):
    """Pydantic model for endpoint-specific metadata. This model is used to store additional info on each EndpointType enum value.

    For documentation on the fields, see the :class:`EndpointType` @property functions.
    """

    # Required capability fields — every endpoint must declare these explicitly.
    supports_streaming: bool = Field(...)
    produces_tokens: bool = Field(...)
    # Optional capabilities and display metadata with conservative defaults.
    supports_audio: bool = Field(default=False)
    supports_images: bool = Field(default=False)
    endpoint_path: str | None = Field(default=None)
    metrics_title: str | None = Field(default=None)

aiperf.common.enums.logging_enums

aiperf.common.enums.measurement_enums

aiperf.common.enums.media_enums

MediaType

Bases: CaseInsensitiveStrEnum

The various types of media (e.g. text, image, audio).

Source code in aiperf/common/enums/media_enums.py
 7
 8
 9
10
11
12
class MediaType(CaseInsensitiveStrEnum):
    """The various types of media (e.g. text, image, audio)."""

    # Values are the lowercase media-type names used for lookup/serialization.
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"

aiperf.common.enums.message_enums

MessageType

Bases: CaseInsensitiveStrEnum

The various types of messages that can be sent between services.

The message type is used to determine what Pydantic model the message maps to, based on the message_type field in the message model. For detailed explanations of each message type, go to its definition in :mod:aiperf.common.messages.

Source code in aiperf/common/enums/message_enums.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class MessageType(CaseInsensitiveStrEnum):
    """The various types of messages that can be sent between services.

    The message type is used to determine what Pydantic model the message maps to,
    based on the message_type field in the message model. For detailed explanations
    of each message type, go to its definition in :mod:`aiperf.common.messages`.
    """

    # NOTE(review): members appear to be kept in alphabetical order — preserve
    # that ordering when adding new message types.
    ALL_RECORDS_RECEIVED = "all_records_received"
    COMMAND = "command"
    COMMAND_RESPONSE = "command_response"
    CONNECTION_PROBE = "connection_probe"
    CONVERSATION_REQUEST = "conversation_request"
    CONVERSATION_RESPONSE = "conversation_response"
    CONVERSATION_TURN_REQUEST = "conversation_turn_request"
    CONVERSATION_TURN_RESPONSE = "conversation_turn_response"
    CREDITS_COMPLETE = "credits_complete"
    CREDIT_DROP = "credit_drop"
    CREDIT_PHASE_COMPLETE = "credit_phase_complete"
    CREDIT_PHASE_PROGRESS = "credit_phase_progress"
    CREDIT_PHASE_SENDING_COMPLETE = "credit_phase_sending_complete"
    CREDIT_PHASE_START = "credit_phase_start"
    CREDIT_RETURN = "credit_return"
    DATASET_CONFIGURED_NOTIFICATION = "dataset_configured_notification"
    DATASET_TIMING_REQUEST = "dataset_timing_request"
    DATASET_TIMING_RESPONSE = "dataset_timing_response"
    ERROR = "error"
    HEARTBEAT = "heartbeat"
    INFERENCE_RESULTS = "inference_results"
    METRIC_RECORDS = "metric_records"
    NOTIFICATION = "notification"
    PARSED_INFERENCE_RESULTS = "parsed_inference_results"
    PROCESSING_STATS = "processing_stats"
    PROCESS_RECORDS_REQUEST = "process_records_request"
    PROCESS_RECORDS_RESPONSE = "process_records_response"
    PROCESS_RECORDS_RESULT = "process_records_result"
    PROFILE_ERROR = "profile_error"
    PROFILE_PROGRESS = "profile_progress"
    PROFILE_RESULTS = "profile_results"
    REGISTRATION = "registration"
    SERVICE_ERROR = "service_error"
    STATUS = "status"
    SWEEP_BEGIN = "sweep_begin"
    SWEEP_CONFIGURE = "sweep_configure"
    SWEEP_END = "sweep_end"
    SWEEP_ERROR = "sweep_error"
    SWEEP_PROGRESS = "sweep_progress"
    SWEEP_RESULTS = "sweep_results"
    UNKNOWN = "unknown"
    WORKER_HEALTH = "worker_health"

NotificationType

Bases: CaseInsensitiveStrEnum

Types of notifications that can be sent to other services.

Source code in aiperf/common/enums/message_enums.py
59
60
61
62
63
class NotificationType(CaseInsensitiveStrEnum):
    """Types of notifications that can be sent to other services."""

    # Currently the only notification type; add new kinds as members here.
    DATASET_CONFIGURED = "dataset_configured"
    """A notification sent to notify other services that the dataset has been configured."""

DATASET_CONFIGURED = 'dataset_configured' class-attribute instance-attribute

A notification sent to notify other services that the dataset has been configured.

aiperf.common.enums.metric_enums

BaseMetricUnit

Bases: BasePydanticBackedStrEnum

Base class for all metric units.

Source code in aiperf/common/enums/metric_enums.py
45
46
47
48
49
50
51
52
53
54
55
class BaseMetricUnit(BasePydanticBackedStrEnum):
    """Base class for all metric units."""

    @cached_property
    def info(self) -> BaseMetricUnitInfo:
        """Get the info for the metric unit."""
        # Narrows the return type of the base class's `info` to BaseMetricUnitInfo.
        return self._info  # type: ignore

    def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
        """Convert a value from this unit to another unit. This is a passthrough to the info class."""
        # The actual conversion logic (and any raised MetricUnitError) lives
        # on the backing info model.
        return self.info.convert_to(other_unit, value)

info cached property

Get the info for the metric unit.

convert_to(other_unit, value)

Convert a value from this unit to another unit. This is a passthrough to the info class.

Source code in aiperf/common/enums/metric_enums.py
53
54
55
def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
    """Convert a value from this unit to another unit. This is a passthrough to the info class."""
    return self.info.convert_to(other_unit, value)

BaseMetricUnitInfo

Bases: BasePydanticEnumInfo

Base class for all metric units. Provides a base implementation for converting between units which can be overridden by subclasses to support more complex conversions.

Source code in aiperf/common/enums/metric_enums.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class BaseMetricUnitInfo(BasePydanticEnumInfo):
    """Base info model for metric units.

    Supplies a default `convert_to` that only handles the identity (no-op)
    conversion; subclasses override it to support richer conversions.
    """

    def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
        """Convert a value from this unit to another unit.

        Raises:
            MetricUnitError: when the target unit differs from this unit and
                no conversion is available.
        """
        # Only the identity conversion is supported at this level. Treating it
        # as a no-op keeps conversion chains working even for unit types that
        # define no conversion method of their own.
        if other_unit != self:
            raise MetricUnitError(
                f"Cannot convert from '{self}' to '{other_unit}'.",
            )
        return value

convert_to(other_unit, value)

Convert a value from this unit to another unit.

Source code in aiperf/common/enums/metric_enums.py
32
33
34
35
36
37
38
39
40
41
42
def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
    """Convert a value from this unit to another unit."""
    # If the other unit is the same as this unit, return the value. This allows for chaining conversions,
    # as well as if a type does not have a conversion method, we do not want to raise an error if the conversion is a no-op.
    if other_unit == self:
        return value

    # Otherwise, we cannot convert between the two units.
    raise MetricUnitError(
        f"Cannot convert from '{self}' to '{other_unit}'.",
    )

GenericMetricUnit

Bases: BaseMetricUnit

Defines generic units for metrics. These don't have any extra information other than the tag, which is used for display purposes.

Source code in aiperf/common/enums/metric_enums.py
187
188
189
190
191
192
193
class GenericMetricUnit(BaseMetricUnit):
    """Defines generic units for metrics. These don't have any extra information other than the tag, which is used for display purposes."""

    # Each member carries only a tag (via the `_unit` helper) for display.
    COUNT = _unit("count")
    REQUESTS = _unit("requests")
    TOKENS = _unit("tokens")
    USER = _unit("user")

MetricDateTimeUnit

Bases: BaseMetricUnit

Defines the various date time units that can be used for metrics.

Source code in aiperf/common/enums/metric_enums.py
196
197
198
199
class MetricDateTimeUnit(BaseMetricUnit):
    """Defines the various date time units that can be used for metrics."""

    # Tag "datetime" — presumably used for absolute timestamp-valued metrics;
    # confirm against the metrics that declare this unit.
    DATE_TIME = _unit("datetime")

MetricFlags

Bases: Flag

Defines the possible flags for metrics that are used to determine how they are processed or grouped. These flags are intended to be an easy way to group metrics, or turn on/off certain features.

Note that the flags are a bitmask, so they can be combined using the bitwise OR operator (|). For example, to create a flag that is both STREAMING_ONLY and HIDDEN, you can do:

MetricFlags.STREAMING_ONLY | MetricFlags.HIDDEN

To check if a metric has a flag, you can use the has_flags method. For example, to check if a metric has both the STREAMING_ONLY and HIDDEN flags, you can do:

metric.has_flags(MetricFlags.STREAMING_ONLY | MetricFlags.HIDDEN)

To check if a metric does not have a flag(s), you can use the missing_flags method. For example, to check if a metric does not have either the STREAMING_ONLY or HIDDEN flags, you can do:

metric.missing_flags(MetricFlags.STREAMING_ONLY | MetricFlags.HIDDEN)
Source code in aiperf/common/enums/metric_enums.py
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
class MetricFlags(Flag):
    """Bitmask flags controlling how metrics are processed and grouped.

    Flags combine with the bitwise OR operator (`|`). For example, a flag that
    is both `STREAMING_ONLY` and `HIDDEN`:
    ```python
    MetricFlags.STREAMING_ONLY | MetricFlags.HIDDEN
    ```

    Use `has_flags` to test that ALL of a set of flags are present:
    ```python
    metric.has_flags(MetricFlags.STREAMING_ONLY | MetricFlags.HIDDEN)
    ```

    Use `missing_flags` to test that NONE of a set of flags are present:
    ```python
    metric.missing_flags(MetricFlags.STREAMING_ONLY | MetricFlags.HIDDEN)
    ```
    """

    # NOTE: These form a bitmask — every base flag must be a distinct power of
    # two; composite flags are OR-combinations of base flags.

    NONE = 0
    """No flags."""

    STREAMING_ONLY = 1 << 0
    """Metrics that are only applicable to streamed responses."""

    ERROR_ONLY = 1 << 1
    """Metrics that are only applicable to error records. By default, metrics are only computed if the record is valid.
    If this flag is set, the metric will only be computed if the record is invalid."""

    PRODUCES_TOKENS_ONLY = 1 << 2
    """Metrics that are only applicable when profiling an endpoint that produces tokens."""

    HIDDEN = 1 << 3
    """Metrics that should not be displayed in the UI."""

    LARGER_IS_BETTER = 1 << 4
    """Metrics that are better when the value is larger. By default, it is assumed that metrics are
    better when the value is smaller."""

    INTERNAL = (1 << 5) | HIDDEN
    """Metrics that are internal to the system and not applicable to the user. This inherently means that the metric
    is HIDDEN as well."""

    SUPPORTS_AUDIO_ONLY = 1 << 6
    """Metrics that are only applicable to audio-based endpoints."""

    SUPPORTS_IMAGE_ONLY = 1 << 7
    """Metrics that are only applicable to image-based endpoints."""

    STREAMING_TOKENS_ONLY = STREAMING_ONLY | PRODUCES_TOKENS_ONLY
    """Metrics that are only applicable to streamed responses and token-based endpoints.
    This is a convenience flag that is the combination of the `STREAMING_ONLY` and `PRODUCES_TOKENS_ONLY` flags."""

    def has_flags(self, flags: "MetricFlags") -> bool:
        """Return True if the metric has ALL of the given flag(s) (regardless of other flags)."""
        # Masking self with the requested flags yields exactly those flags
        # only when every one of them is set on self.
        return (self & flags) == flags

    def missing_flags(self, flags: "MetricFlags") -> bool:
        """Return True if the metric does not have ANY of the given flag(s) (regardless of other flags). It will
        return False if the metric has ANY of the given flags. If the input flags are NONE, it will return True."""
        # No flags requested — vacuously true.
        if flags == MetricFlags.NONE:
            return True
        # A zero (falsy) intersection means none of the requested flags are set.
        return not (self & flags)

ERROR_ONLY = 1 << 1 class-attribute instance-attribute

Metrics that are only applicable to error records. By default, metrics are only computed if the record is valid. If this flag is set, the metric will only be computed if the record is invalid.

HIDDEN = 1 << 3 class-attribute instance-attribute

Metrics that should not be displayed in the UI.

INTERNAL = 1 << 5 | HIDDEN class-attribute instance-attribute

Metrics that are internal to the system and not applicable to the user. This inherently means that the metric is HIDDEN as well.

LARGER_IS_BETTER = 1 << 4 class-attribute instance-attribute

Metrics that are better when the value is larger. By default, it is assumed that metrics are better when the value is smaller.

NONE = 0 class-attribute instance-attribute

No flags.

PRODUCES_TOKENS_ONLY = 1 << 2 class-attribute instance-attribute

Metrics that are only applicable when profiling an endpoint that produces tokens.

STREAMING_ONLY = 1 << 0 class-attribute instance-attribute

Metrics that are only applicable to streamed responses.

STREAMING_TOKENS_ONLY = STREAMING_ONLY | PRODUCES_TOKENS_ONLY class-attribute instance-attribute

Metrics that are only applicable to streamed responses and token-based endpoints. This is a convenience flag that is the combination of the STREAMING_ONLY and PRODUCES_TOKENS_ONLY flags.

SUPPORTS_AUDIO_ONLY = 1 << 6 class-attribute instance-attribute

Metrics that are only applicable to audio-based endpoints.

SUPPORTS_IMAGE_ONLY = 1 << 7 class-attribute instance-attribute

Metrics that are only applicable to image-based endpoints.

has_flags(flags)

Return True if the metric has ALL of the given flag(s) (regardless of other flags).

Source code in aiperf/common/enums/metric_enums.py
422
423
424
425
def has_flags(self, flags: "MetricFlags") -> bool:
    """Return True if the metric has ALL of the given flag(s) (regardless of other flags)."""
    # Bitwise AND will return the input flags only if all of the given flags are present.
    return (flags & self) == flags

missing_flags(flags)

Return True if the metric does not have ANY of the given flag(s) (regardless of other flags). It will return False if the metric has ANY of the given flags. If the input flags are NONE, it will return True.

Source code in aiperf/common/enums/metric_enums.py
427
428
429
430
431
432
433
434
435
def missing_flags(self, flags: "MetricFlags") -> bool:
    """Return True if the metric does not have ANY of the given flag(s) (regardless of other flags). It will
    return False if the metric has ANY of the given flags. If the input flags are NONE, it will return True."""
    if flags == MetricFlags.NONE:
        return True  # If there are no flags to check, return True

    # Bitwise AND will return 0 (MetricFlags.NONE) if there are no common flags.
    # If there are some missing, but some found, the result will not be 0.
    return (self & flags) == MetricFlags.NONE

MetricOverTimeUnit

Bases: BaseMetricUnit

Defines the units for metrics that are a generic unit over a specific time unit.

Source code in aiperf/common/enums/metric_enums.py
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
class MetricOverTimeUnit(BaseMetricUnit):
    """Defines the units for metrics that are a generic unit over a specific time unit."""

    REQUESTS_PER_SECOND = MetricOverTimeUnitInfo(
        primary_unit=GenericMetricUnit.REQUESTS,
        time_unit=MetricTimeUnit.SECONDS,
    )
    TOKENS_PER_SECOND = MetricOverTimeUnitInfo(
        primary_unit=GenericMetricUnit.TOKENS,
        time_unit=MetricTimeUnit.SECONDS,
    )
    TOKENS_PER_SECOND_PER_USER = MetricOverTimeUnitInfo(
        primary_unit=GenericMetricUnit.TOKENS,
        time_unit=MetricTimeUnit.SECONDS,
        third_unit=GenericMetricUnit.USER,
    )

    @cached_property
    def info(self) -> MetricOverTimeUnitInfo:
        """Get the info for the metric over time unit."""
        return self._info  # type: ignore

    @cached_property
    def primary_unit(self) -> "MetricUnitT":
        """Get the primary unit."""
        return self.info.primary_unit

    @cached_property
    def time_unit(self) -> MetricTimeUnit | MetricTimeUnitInfo:
        """Get the time unit."""
        return self.info.time_unit

    @cached_property
    def third_unit(self) -> "MetricUnitT | None":
        """Get the third unit (if applicable)."""
        return self.info.third_unit

info cached property

Get the info for the metric over time unit.

primary_unit cached property

Get the primary unit.

third_unit cached property

Get the third unit (if applicable).

time_unit cached property

Get the time unit.

MetricOverTimeUnitInfo

Bases: BaseMetricUnitInfo

Information about a metric over time unit.

Source code in aiperf/common/enums/metric_enums.py
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
class MetricOverTimeUnitInfo(BaseMetricUnitInfo):
    """Information about a metric over time unit."""

    @model_validator(mode="after")
    def _set_tag(self: Self) -> Self:
        """Set the tag based on the existing units. ie. requests/sec, tokens/sec, etc."""
        self.tag = f"{self.primary_unit}/{self.time_unit}"
        if self.third_unit:
            # If there is a third unit, add it to the tag. ie. tokens/sec/user
            self.tag += f"/{self.third_unit}"
        return self

    tag: str = Field(
        default="",
        description="The tag for the metric over time unit. This will be set automatically by the model_validator.",
    )
    primary_unit: "MetricUnitT"
    time_unit: MetricTimeUnit | MetricTimeUnitInfo
    third_unit: "MetricUnitT | None" = None

    def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
        """Convert a value from this unit to another unit."""
        # If the other unit is the same as this unit, return the value.
        if other_unit == self:
            return value

        if isinstance(other_unit, MetricOverTimeUnit | MetricOverTimeUnitInfo):
            # Chain convert each unit to the other unit.
            value = self.primary_unit.convert_to(other_unit.primary_unit, value)
            value = self.time_unit.convert_to(other_unit.time_unit, value)
            if self.third_unit and other_unit.third_unit:
                value = self.third_unit.convert_to(other_unit.third_unit, value)
            return value

        # If the other unit is a time unit, convert our time unit to the other unit.
        # TODO: Should we even allow this?
        if isinstance(other_unit, MetricTimeUnit | MetricTimeUnitInfo):
            return self.time_unit.convert_to(other_unit, value)

        # Otherwise, convert the primary unit to the other unit.
        return self.primary_unit.convert_to(other_unit, value)

convert_to(other_unit, value)

Convert a value from this unit to another unit.

Source code in aiperf/common/enums/metric_enums.py
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
    """Convert a value from this unit to another unit."""
    # If the other unit is the same as this unit, return the value.
    if other_unit == self:
        return value

    if isinstance(other_unit, MetricOverTimeUnit | MetricOverTimeUnitInfo):
        # Chain convert each unit to the other unit.
        value = self.primary_unit.convert_to(other_unit.primary_unit, value)
        value = self.time_unit.convert_to(other_unit.time_unit, value)
        if self.third_unit and other_unit.third_unit:
            value = self.third_unit.convert_to(other_unit.third_unit, value)
        return value

    # If the other unit is a time unit, convert our time unit to the other unit.
    # TODO: Should we even allow this?
    if isinstance(other_unit, MetricTimeUnit | MetricTimeUnitInfo):
        return self.time_unit.convert_to(other_unit, value)

    # Otherwise, convert the primary unit to the other unit.
    return self.primary_unit.convert_to(other_unit, value)

MetricSizeUnit

Bases: BaseMetricUnit

Defines the size types for metrics.

Source code in aiperf/common/enums/metric_enums.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
class MetricSizeUnit(BaseMetricUnit):
    """Defines the size types for metrics."""

    BYTES = MetricSizeUnitInfo(
        tag="B",
        long_name="bytes",
        num_bytes=1,
    )
    KILOBYTES = MetricSizeUnitInfo(
        tag="KB",
        long_name="kilobytes",
        num_bytes=1024,
    )
    MEGABYTES = MetricSizeUnitInfo(
        tag="MB",
        long_name="megabytes",
        num_bytes=1024 * 1024,
    )
    GIGABYTES = MetricSizeUnitInfo(
        tag="GB",
        long_name="gigabytes",
        num_bytes=1024 * 1024 * 1024,
    )
    TERABYTES = MetricSizeUnitInfo(
        tag="TB",
        long_name="terabytes",
        num_bytes=1024 * 1024 * 1024 * 1024,
    )

    @cached_property
    def info(self) -> MetricSizeUnitInfo:
        """Get the info for the metric size unit."""
        return self._info  # type: ignore

    @cached_property
    def num_bytes(self) -> int:
        """The number of bytes in the metric size unit."""
        return self.info.num_bytes

    @cached_property
    def long_name(self) -> str:
        """The long name of the metric size unit."""
        return self.info.long_name

info cached property

Get the info for the metric size unit.

long_name cached property

The long name of the metric size unit.

num_bytes cached property

The number of bytes in the metric size unit.

MetricSizeUnitInfo

Bases: BaseMetricUnitInfo

Information about a size unit for metrics.

Source code in aiperf/common/enums/metric_enums.py
62
63
64
65
66
67
68
69
70
71
72
73
class MetricSizeUnitInfo(BaseMetricUnitInfo):
    """Information about a size unit for metrics."""

    long_name: str
    num_bytes: int

    def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
        """Convert a value from this unit to another unit."""
        if not isinstance(other_unit, MetricSizeUnit | MetricSizeUnitInfo):
            return super().convert_to(other_unit, value)

        return value * (self.num_bytes / other_unit.num_bytes)

convert_to(other_unit, value)

Convert a value from this unit to another unit.

Source code in aiperf/common/enums/metric_enums.py
68
69
70
71
72
73
def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
    """Convert a value from this unit to another unit."""
    if not isinstance(other_unit, MetricSizeUnit | MetricSizeUnitInfo):
        return super().convert_to(other_unit, value)

    return value * (self.num_bytes / other_unit.num_bytes)

MetricTimeUnit

Bases: BaseMetricUnit

Defines the various time units that can be used for metrics, as well as the conversion factor to convert to other units.

Source code in aiperf/common/enums/metric_enums.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
class MetricTimeUnit(BaseMetricUnit):
    """Defines the various time units that can be used for metrics, as well as the conversion factor to convert to other units."""

    NANOSECONDS = MetricTimeUnitInfo(
        tag="ns",
        long_name="nanoseconds",
        per_second=1_000_000_000,
    )
    MICROSECONDS = MetricTimeUnitInfo(
        tag="us",
        long_name="microseconds",
        per_second=1_000_000,
    )
    MILLISECONDS = MetricTimeUnitInfo(
        tag="ms",
        long_name="milliseconds",
        per_second=1_000,
    )
    SECONDS = MetricTimeUnitInfo(
        tag="sec",
        long_name="seconds",
        per_second=1,
    )

    @cached_property
    def info(self) -> MetricTimeUnitInfo:
        """Get the info for the metric time unit."""
        return self._info  # type: ignore

    @cached_property
    def per_second(self) -> int:
        """How many of these units there are in one second. Used as a common conversion factor to convert to other units."""
        return self.info.per_second

    @cached_property
    def long_name(self) -> str:
        """The long name of the metric time unit."""
        return self.info.long_name

    def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
        """Convert a value from this unit to another unit."""
        if not isinstance(
            other_unit, MetricTimeUnit | MetricTimeUnitInfo | MetricDateTimeUnit
        ):
            return super().convert_to(other_unit, value)

        if isinstance(other_unit, MetricDateTimeUnit):
            return datetime.fromtimestamp(
                self.convert_to(MetricTimeUnit.SECONDS, value)
            )

        return value * (other_unit.per_second / self.per_second)

info cached property

Get the info for the metric time unit.

long_name cached property

The long name of the metric time unit.

per_second cached property

How many of these units there are in one second. Used as a common conversion factor to convert to other units.

convert_to(other_unit, value)

Convert a value from this unit to another unit.

Source code in aiperf/common/enums/metric_enums.py
167
168
169
170
171
172
173
174
175
176
177
178
179
def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
    """Convert a value from this unit to another unit."""
    if not isinstance(
        other_unit, MetricTimeUnit | MetricTimeUnitInfo | MetricDateTimeUnit
    ):
        return super().convert_to(other_unit, value)

    if isinstance(other_unit, MetricDateTimeUnit):
        return datetime.fromtimestamp(
            self.convert_to(MetricTimeUnit.SECONDS, value)
        )

    return value * (other_unit.per_second / self.per_second)

MetricTimeUnitInfo

Bases: BaseMetricUnitInfo

Information about a time unit for metrics.

Source code in aiperf/common/enums/metric_enums.py
121
122
123
124
125
class MetricTimeUnitInfo(BaseMetricUnitInfo):
    """Information about a time unit for metrics."""

    long_name: str
    per_second: int

MetricType

Bases: CaseInsensitiveStrEnum

Defines the possible types of metrics.

Source code in aiperf/common/enums/metric_enums.py
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
class MetricType(CaseInsensitiveStrEnum):
    """Defines the possible types of metrics."""

    RECORD = "record"
    """Metrics that provide a distinct value for each request. Every request that comes in will produce a new value that is not affected by any other requests.
    These metrics can be tracked over time and compared to each other.
    Examples: request latency, ISL, ITL, OSL, etc."""

    AGGREGATE = "aggregate"
    """Metrics that keep track of one or more values over time, that are updated for each request, such as total counts, min/max values, etc.
    These metrics may or may not change with each request, and are affected by other requests.
    Examples: min/max request latency, total request count, benchmark duration, etc."""

    DERIVED = "derived"
    """Metrics that are purely derived from other metrics as a summary, and do not require per-request values.
    Examples: request throughput, output token throughput, etc."""

AGGREGATE = 'aggregate' class-attribute instance-attribute

Metrics that keep track of one or more values over time, that are updated for each request, such as total counts, min/max values, etc. These metrics may or may not change with each request, and are affected by other requests. Examples: min/max request latency, total request count, benchmark duration, etc.

DERIVED = 'derived' class-attribute instance-attribute

Metrics that are purely derived from other metrics as a summary, and do not require per-request values. Examples: request throughput, output token throughput, etc.

RECORD = 'record' class-attribute instance-attribute

Metrics that provide a distinct value for each request. Every request that comes in will produce a new value that is not affected by any other requests. These metrics can be tracked over time and compared to each other. Examples: request latency, ISL, ITL, OSL, etc.

MetricValueType

Bases: BasePydanticBackedStrEnum

Defines the possible types of values for metrics.

NOTE: The string representation is important here, as it is used to automatically determine the type based on the python generic type definition.

Source code in aiperf/common/enums/metric_enums.py
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
class MetricValueType(BasePydanticBackedStrEnum):
    """Defines the possible types of values for metrics.

    NOTE: The string representation is important here, as it is used to automatically determine the type
    based on the python generic type definition.
    """

    FLOAT = MetricValueTypeInfo(
        tag="float",
        default_factory=float,
        converter=float,
        dtype=float,
    )
    INT = MetricValueTypeInfo(
        tag="int",
        default_factory=int,
        converter=int,
        dtype=int,
    )

    @cached_property
    def info(self) -> MetricValueTypeInfo:
        """Get the info for the metric value type."""
        return self._info  # type: ignore

    @cached_property
    def default_factory(self) -> Callable[[], MetricValueTypeT]:
        """Get the default value generator for the metric value type."""
        return self.info.default_factory

    @cached_property
    def converter(self) -> Callable[[Any], MetricValueTypeT]:
        """Get the converter for the metric value type."""
        return self.info.converter

    @cached_property
    def dtype(self) -> Any:
        """Get the dtype for the metric value type (for pandas/numpy)."""
        return self.info.dtype

    @classmethod
    def from_python_type(cls, type: type[MetricValueTypeT]) -> "MetricValueType":
        """Get the MetricValueType for a given type."""
        # If the type is a simple type like float or int, we have to use __name__.
        # This is because using str() on float or int will return <class 'float'> or <class 'int'>, etc.
        type_name = type.__name__
        if type_name == "list":
            # However, if the type is a list, we have to use str() to get the list type as well, e.g. list[int]
            type_name = str(type)
        elif type_name == "MetricValueTypeVarT":
            type_name = "float"  # Default to float if the user did not specify a type.
        return MetricValueType(type_name)

converter cached property

Get the converter for the metric value type.

default_factory cached property

Get the default value generator for the metric value type.

dtype cached property

Get the dtype for the metric value type (for pandas/numpy).

info cached property

Get the info for the metric value type.

from_python_type(type) classmethod

Get the MetricValueType for a given type.

Source code in aiperf/common/enums/metric_enums.py
349
350
351
352
353
354
355
356
357
358
359
360
@classmethod
def from_python_type(cls, type: type[MetricValueTypeT]) -> "MetricValueType":
    """Get the MetricValueType for a given type."""
    # If the type is a simple type like float or int, we have to use __name__.
    # This is because using str() on float or int will return <class 'float'> or <class 'int'>, etc.
    type_name = type.__name__
    if type_name == "list":
        # However, if the type is a list, we have to use str() to get the list type as well, e.g. list[int]
        type_name = str(type)
    elif type_name == "MetricValueTypeVarT":
        type_name = "float"  # Default to float if the user did not specify a type.
    return MetricValueType(type_name)

MetricValueTypeInfo

Bases: BasePydanticEnumInfo

Information about a metric value type.

Source code in aiperf/common/enums/metric_enums.py
301
302
303
304
305
306
class MetricValueTypeInfo(BasePydanticEnumInfo):
    """Information about a metric value type."""

    default_factory: Callable[[], MetricValueTypeT]
    converter: Callable[[Any], MetricValueTypeT]
    dtype: Any

aiperf.common.enums.model_enums

Modality

Bases: CaseInsensitiveStrEnum

Modality of the model. Can be used to determine the type of data to send to the model in conjunction with the ModelSelectionStrategy.MODALITY_AWARE.

Source code in aiperf/common/enums/model_enums.py
 7
 8
 9
10
11
12
13
14
15
16
class Modality(CaseInsensitiveStrEnum):
    """Modality of the model. Can be used to determine the type of data to send to the model in
    conjunction with the ModelSelectionStrategy.MODALITY_AWARE."""

    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    MULTIMODAL = "multimodal"
    CUSTOM = "custom"

ModelSelectionStrategy

Bases: CaseInsensitiveStrEnum

Strategy for selecting the model to use for the request.

Source code in aiperf/common/enums/model_enums.py
19
20
21
22
23
class ModelSelectionStrategy(CaseInsensitiveStrEnum):
    """Strategy for selecting the model to use for the request."""

    ROUND_ROBIN = "round_robin"
    RANDOM = "random"

aiperf.common.enums.post_processor_enums

RecordProcessorType

Bases: CaseInsensitiveStrEnum

Type of streaming record processor.

Source code in aiperf/common/enums/post_processor_enums.py
 7
 8
 9
10
11
12
class RecordProcessorType(CaseInsensitiveStrEnum):
    """Type of streaming record processor."""

    METRIC_RECORD = "metric_record"
    """Streamer that streams records and computes metrics from MetricType.RECORD and MetricType.AGGREGATE.
    This is the first stage of the metrics processing pipeline, and is done in a distributed manner across multiple service instances."""

METRIC_RECORD = 'metric_record' class-attribute instance-attribute

Streamer that streams records and computes metrics from MetricType.RECORD and MetricType.AGGREGATE. This is the first stage of the metrics processing pipeline, and is done in a distributed manner across multiple service instances.

ResultsProcessorType

Bases: CaseInsensitiveStrEnum

Type of streaming results processor.

Source code in aiperf/common/enums/post_processor_enums.py
15
16
17
18
19
20
class ResultsProcessorType(CaseInsensitiveStrEnum):
    """Type of streaming results processor."""

    METRIC_RESULTS = "metric_results"
    """Processor that processes the metric results from METRIC_RECORD and computes metrics from MetricType.DERIVED, as well as aggregates the results.
    This is the last stage of the metrics processing pipeline, and is done from the RecordsManager after all the service instances have completed their processing."""

METRIC_RESULTS = 'metric_results' class-attribute instance-attribute

Processor that processes the metric results from METRIC_RECORD and computes metrics from MetricType.DERIVED, as well as aggregates the results. This is the last stage of the metrics processing pipeline, and is done from the RecordsManager after all the service instances have completed their processing.

aiperf.common.enums.service_enums

LifecycleState

Bases: CaseInsensitiveStrEnum

This is the various states a lifecycle can be in.

Source code in aiperf/common/enums/service_enums.py
19
20
21
22
23
24
25
26
27
28
29
class LifecycleState(CaseInsensitiveStrEnum):
    """This is the various states a lifecycle can be in."""

    CREATED = "created"
    INITIALIZING = "initializing"
    INITIALIZED = "initialized"
    STARTING = "starting"
    RUNNING = "running"
    STOPPING = "stopping"
    STOPPED = "stopped"
    FAILED = "failed"

ServiceRegistrationStatus

Bases: CaseInsensitiveStrEnum

Defines the various states a service can be in during registration with the SystemController.

Source code in aiperf/common/enums/service_enums.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class ServiceRegistrationStatus(CaseInsensitiveStrEnum):
    """Defines the various states a service can be in during registration with
    the SystemController."""

    UNREGISTERED = "unregistered"
    """The service is not registered with the SystemController. This is the
    initial state."""

    WAITING = "waiting"
    """The service is waiting for the SystemController to register it.
    This is a temporary state that should be followed by REGISTERED, TIMEOUT, or ERROR."""

    REGISTERED = "registered"
    """The service is registered with the SystemController."""

    TIMEOUT = "timeout"
    """The service registration timed out."""

    ERROR = "error"
    """The service registration failed."""

ERROR = 'error' class-attribute instance-attribute

The service registration failed.

REGISTERED = 'registered' class-attribute instance-attribute

The service is registered with the SystemController.

TIMEOUT = 'timeout' class-attribute instance-attribute

The service registration timed out.

UNREGISTERED = 'unregistered' class-attribute instance-attribute

The service is not registered with the SystemController. This is the initial state.

WAITING = 'waiting' class-attribute instance-attribute

The service is waiting for the SystemController to register it. This is a temporary state that should be followed by REGISTERED, TIMEOUT, or ERROR.

ServiceRunType

Bases: CaseInsensitiveStrEnum

The different ways the SystemController should run the component services.

Source code in aiperf/common/enums/service_enums.py
 7
 8
 9
10
11
12
13
14
15
16
class ServiceRunType(CaseInsensitiveStrEnum):
    """The different ways the SystemController should run the component services."""

    MULTIPROCESSING = "process"
    """Run each service as a separate process.
    This is the default way for single-node deployments."""

    KUBERNETES = "k8s"
    """Run each service as a separate Kubernetes pod.
    This is the default way for multi-node deployments."""

KUBERNETES = 'k8s' class-attribute instance-attribute

Run each service as a separate Kubernetes pod. This is the default way for multi-node deployments.

MULTIPROCESSING = 'process' class-attribute instance-attribute

Run each service as a separate process. This is the default way for single-node deployments.

ServiceType

Bases: CaseInsensitiveStrEnum

Types of services in the AIPerf system.

This is used to identify the service type when registering with the SystemController. It can also be used for tracking purposes if multiple instances of the same service type are running.

Source code in aiperf/common/enums/service_enums.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class ServiceType(CaseInsensitiveStrEnum):
    """Types of services in the AIPerf system.

    This is used to identify the service type when registering with the
    SystemController. It can also be used for tracking purposes if multiple
    instances of the same service type are running.
    """

    SYSTEM_CONTROLLER = "system_controller"
    DATASET_MANAGER = "dataset_manager"
    TIMING_MANAGER = "timing_manager"
    RECORD_PROCESSOR = "record_processor"
    RECORDS_MANAGER = "records_manager"
    WORKER_MANAGER = "worker_manager"
    WORKER = "worker"

    # For testing purposes only
    TEST = "test_service"

aiperf.common.enums.sse_enums

SSEEventType

Bases: CaseInsensitiveStrEnum

Event types in an SSE message. Many of these are custom and not defined by the SSE spec.

Source code in aiperf/common/enums/sse_enums.py
17
18
19
20
21
class SSEEventType(CaseInsensitiveStrEnum):
    """Event types in an SSE message. Many of these are custom and not defined by the SSE spec."""

    ERROR = "error"
    LLM_METRICS = "llm_metrics"

SSEFieldType

Bases: CaseInsensitiveStrEnum

Field types in an SSE message.

Source code in aiperf/common/enums/sse_enums.py
 7
 8
 9
10
11
12
13
14
class SSEFieldType(CaseInsensitiveStrEnum):
    """Field types in an SSE message."""

    DATA = "data"
    EVENT = "event"
    ID = "id"
    RETRY = "retry"
    COMMENT = "comment"

aiperf.common.enums.system_enums

SystemState

Bases: CaseInsensitiveStrEnum

State of the system as a whole.

This is used to track the state of the system as a whole, and is used to determine what actions to take when a signal is received.

Source code in aiperf/common/enums/system_enums.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class SystemState(CaseInsensitiveStrEnum):
    """State of the system as a whole.

    This is used to track the state of the system as a whole, and is used to
    determine what actions to take when a signal is received.
    """

    INITIALIZING = "initializing"
    """The system is initializing. This is the initial state."""

    CONFIGURING = "configuring"
    """The system is configuring services."""

    READY = "ready"
    """The system is ready to start profiling. This is a temporary state that should be
    followed by PROFILING."""

    PROFILING = "profiling"
    """The system is running a profiling run."""

    PROCESSING = "processing"
    """The system is processing results."""

    STOPPING = "stopping"
    """The system is stopping."""

    SHUTDOWN = "shutdown"
    """The system is shutting down. This is the final state."""

CONFIGURING = 'configuring' class-attribute instance-attribute

The system is configuring services.

INITIALIZING = 'initializing' class-attribute instance-attribute

The system is initializing. This is the initial state.

PROCESSING = 'processing' class-attribute instance-attribute

The system is processing results.

PROFILING = 'profiling' class-attribute instance-attribute

The system is running a profiling run.

READY = 'ready' class-attribute instance-attribute

The system is ready to start profiling. This is a temporary state that should be followed by PROFILING.

SHUTDOWN = 'shutdown' class-attribute instance-attribute

The system is shutting down. This is the final state.

STOPPING = 'stopping' class-attribute instance-attribute

The system is stopping.

aiperf.common.enums.timing_enums

CreditPhase

Bases: CaseInsensitiveStrEnum

The type of credit phase. This is used to identify which phase of the benchmark the credit is being used in, for tracking and reporting purposes.

Source code in aiperf/common/enums/timing_enums.py
30
31
32
33
34
35
36
37
38
39
40
class CreditPhase(CaseInsensitiveStrEnum):
    """The type of credit phase. This is used to identify which phase of the
    benchmark the credit is being used in, for tracking and reporting purposes."""

    WARMUP = "warmup"
    """The credit phase is the warmup phase. This is used to warm up the model
    before the benchmark starts."""

    PROFILING = "profiling"
    """The credit phase is the steady state phase. This is the primary phase of the
    benchmark, and what is used to calculate the final results."""

PROFILING = 'profiling' class-attribute instance-attribute

The credit phase is the steady state phase. This is the primary phase of the benchmark, and what is used to calculate the final results.

WARMUP = 'warmup' class-attribute instance-attribute

The credit phase is the warmup phase. This is used to warm up the model before the benchmark starts.

RequestRateMode

Bases: CaseInsensitiveStrEnum

The different ways the RequestRateStrategy should generate requests.

Source code in aiperf/common/enums/timing_enums.py
20
21
22
23
24
25
26
27
class RequestRateMode(CaseInsensitiveStrEnum):
    """The different ways the RequestRateStrategy should generate requests."""

    CONSTANT = "constant"
    """Generate requests at a constant rate."""

    POISSON = "poisson"
    """Generate requests using a poisson distribution."""

CONSTANT = 'constant' class-attribute instance-attribute

Generate requests at a constant rate.

POISSON = 'poisson' class-attribute instance-attribute

Generate requests using a poisson distribution.

TimingMode

Bases: CaseInsensitiveStrEnum

The different ways the TimingManager should generate requests.

Source code in aiperf/common/enums/timing_enums.py
 7
 8
 9
10
11
12
13
14
15
16
17
class TimingMode(CaseInsensitiveStrEnum):
    """The different ways the TimingManager should generate requests."""

    FIXED_SCHEDULE = "fixed_schedule"
    """A mode where the TimingManager will send requests according to a fixed schedule."""

    CONCURRENCY = "concurrency"
    """A mode where the TimingManager will maintain a continuous stream of concurrent requests."""

    REQUEST_RATE = "request_rate"
    """A mode where the TimingManager will send requests at either a constant request rate or based on a poisson distribution."""

CONCURRENCY = 'concurrency' class-attribute instance-attribute

A mode where the TimingManager will maintain a continuous stream of concurrent requests.

FIXED_SCHEDULE = 'fixed_schedule' class-attribute instance-attribute

A mode where the TimingManager will send requests according to a fixed schedule.

REQUEST_RATE = 'request_rate' class-attribute instance-attribute

A mode where the TimingManager will send requests at either a constant request rate or based on a poisson distribution.

aiperf.common.exceptions

AIPerfError

Bases: Exception

Base class for all exceptions raised by AIPerf.

Source code in aiperf/common/exceptions.py
10
11
12
13
14
15
16
17
18
19
class AIPerfError(Exception):
    """Base class for all exceptions raised by AIPerf."""

    def raw_str(self) -> str:
        """Return the raw string representation of the exception."""
        return super().__str__()

    def __str__(self) -> str:
        """Return the string representation of the exception with the class name."""
        return f"{self.__class__.__name__}: {super().__str__()}"

__str__()

Return the string representation of the exception with the class name.

Source code in aiperf/common/exceptions.py
17
18
19
def __str__(self) -> str:
    """Return the string representation of the exception with the class name."""
    return f"{self.__class__.__name__}: {super().__str__()}"

raw_str()

Return the raw string representation of the exception.

Source code in aiperf/common/exceptions.py
13
14
15
def raw_str(self) -> str:
    """Return the raw string representation of the exception."""
    return super().__str__()

AIPerfMultiError

Bases: AIPerfError

Exception raised when running multiple tasks and one or more fail.

Source code in aiperf/common/exceptions.py
22
23
24
25
26
27
28
29
30
class AIPerfMultiError(AIPerfError):
    """Aggregate error raised when running multiple tasks and one or more fail."""

    def __init__(self, message: str, exceptions: list[Exception]) -> None:
        # Use the un-prefixed message for AIPerf errors to avoid repeating
        # class names in the combined message.
        parts: list[str] = []
        for exc in exceptions:
            parts.append(exc.raw_str() if isinstance(exc, AIPerfError) else str(exc))
        super().__init__(f"{message}: {','.join(parts)}")
        self.exceptions = exceptions

CommunicationError

Bases: AIPerfError

Generic communication error.

Source code in aiperf/common/exceptions.py
49
50
class CommunicationError(AIPerfError):
    """Raised for a generic communication failure."""

ConfigurationError

Bases: AIPerfError

Exception raised when something fails to configure, or there is a configuration error.

Source code in aiperf/common/exceptions.py
53
54
class ConfigurationError(AIPerfError):
    """Raised when configuration fails or a configuration value is invalid."""

DatasetError

Bases: AIPerfError

Generic dataset error.

Source code in aiperf/common/exceptions.py
57
58
class DatasetError(AIPerfError):
    """Raised for a generic dataset failure."""

DatasetGeneratorError

Bases: AIPerfError

Generic dataset generator error.

Source code in aiperf/common/exceptions.py
61
62
class DatasetGeneratorError(AIPerfError):
    """Raised for a generic dataset generator failure."""

FactoryCreationError

Bases: AIPerfError

Exception raised when a factory encounters an error while creating a class.

Source code in aiperf/common/exceptions.py
65
66
class FactoryCreationError(AIPerfError):
    """Raised when a factory fails to create an instance of a registered class."""

InferenceClientError

Bases: AIPerfError

Exception raised when an inference client encounters an error.

Source code in aiperf/common/exceptions.py
73
74
class InferenceClientError(AIPerfError):
    """Exception raised when an inference client encounters an error."""

InitializationError

Bases: AIPerfError

Exception raised when something fails to initialize.

Source code in aiperf/common/exceptions.py
69
70
class InitializationError(AIPerfError):
    """Raised when a component fails to initialize."""

InvalidOperationError

Bases: AIPerfError

Exception raised when an operation is invalid.

Source code in aiperf/common/exceptions.py
77
78
class InvalidOperationError(AIPerfError):
    """Raised when an operation is not valid in the current context."""

InvalidPayloadError

Bases: InferenceClientError

Exception raised when an inference client receives an invalid payload.

Source code in aiperf/common/exceptions.py
81
82
class InvalidPayloadError(InferenceClientError):
    """Exception raised when an inference client receives an invalid payload."""

InvalidStateError

Bases: AIPerfError

Exception raised when something is in an invalid state.

Source code in aiperf/common/exceptions.py
85
86
class InvalidStateError(AIPerfError):
    """Raised when a component is in an invalid state."""

MetricTypeError

Bases: AIPerfError

Exception raised when a metric type encounters an error while creating a class.

Source code in aiperf/common/exceptions.py
89
90
class MetricTypeError(AIPerfError):
    """Raised when a metric type encounters an error while creating a class."""

MetricUnitError

Bases: AIPerfError

Exception raised when trying to convert a metric to or from a unit that it does not support.

Source code in aiperf/common/exceptions.py
93
94
class MetricUnitError(AIPerfError):
    """Exception raised when trying to convert a metric to or from a unit that it does not support."""

NotFoundError

Bases: AIPerfError

Exception raised when something is not found or not available.

Source code in aiperf/common/exceptions.py
97
98
class NotFoundError(AIPerfError):
    """Raised when a requested item is not found or not available."""

NotInitializedError

Bases: AIPerfError

Exception raised when something that should be initialized is not.

Source code in aiperf/common/exceptions.py
101
102
class NotInitializedError(AIPerfError):
    """Raised when a component that should be initialized is not."""

ProxyError

Bases: AIPerfError

Exception raised when a proxy encounters an error.

Source code in aiperf/common/exceptions.py
105
106
class ProxyError(AIPerfError):
    """Raised when a proxy operation fails."""

ServiceError

Bases: AIPerfError

Generic service error.

Source code in aiperf/common/exceptions.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class ServiceError(AIPerfError):
    """Generic error tied to a specific service instance.

    Keeps the service type and id on the exception so handlers can
    inspect them programmatically.
    """

    def __init__(
        self,
        message: str,
        service_type: "ServiceTypeT",
        service_id: str,
    ) -> None:
        # Stash the identifying attributes first; they are independent of the
        # base-class message construction below.
        self.service_type = service_type
        self.service_id = service_id
        super().__init__(
            f"{message} for service of type {service_type} with id {service_id}"
        )

ShutdownError

Bases: AIPerfError

Exception raised when a service encounters an error while shutting down.

Source code in aiperf/common/exceptions.py
109
110
class ShutdownError(AIPerfError):
    """Raised when a service fails while shutting down."""

UnsupportedHookError

Bases: AIPerfError

Exception raised when a hook is defined on a class that does not have any base classes that provide that hook type.

Source code in aiperf/common/exceptions.py
113
114
class UnsupportedHookError(AIPerfError):
    """Raised when a hook is declared on a class whose base classes do not provide that hook type."""

ValidationError

Bases: AIPerfError

Exception raised when something fails validation.

Source code in aiperf/common/exceptions.py
117
118
class ValidationError(AIPerfError):
    """Raised when something fails validation."""

aiperf.common.factories

AIPerfFactory

Bases: Generic[ClassEnumT, ClassProtocolT]

Defines a custom factory for AIPerf components.

This class is used to create a factory for a given class type and protocol.

Example:

    # Define a new enum for the expected implementation types
    # This is optional, but recommended for type safety.
    class DatasetLoaderType(CaseInsensitiveStrEnum):
        FILE = "file"
        S3 = "s3"

    # Define a new class protocol.
    class DatasetLoaderProtocol(Protocol):
        def load(self) -> Dataset:
            pass

    # Create a new factory for a given class type and protocol.
    class DatasetFactory(FactoryMixin[DatasetLoaderType, DatasetLoaderProtocol]):
        pass

    # Register a new class type mapping to its corresponding class. It should implement the class protocol.
    @DatasetFactory.register(DatasetLoaderType.FILE)
    class FileDatasetLoader:
        def __init__(self, filename: str):
            self.filename = filename

        def load(self) -> Dataset:
            return Dataset.from_file(self.filename)

    DatasetConfig = {
        "type": DatasetLoaderType.FILE,
        "filename": "data.csv"
    }

    # Create a new instance of the class.
    if DatasetConfig["type"] == DatasetLoaderType.FILE:
        dataset_instance = DatasetFactory.create_instance(DatasetLoaderType.FILE, filename=DatasetConfig["filename"])
    else:
        raise ValueError(f"Unsupported dataset loader type: {DatasetConfig['type']}")

    dataset_instance.load()
Source code in aiperf/common/factories.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
class AIPerfFactory(Generic[ClassEnumT, ClassProtocolT]):
    """Defines a custom factory for AIPerf components.

    This class is used to create a factory for a given class type and protocol.

    Example:
    ```python
        # Define a new enum for the expected implementation types
        # This is optional, but recommended for type safety.
        class DatasetLoaderType(CaseInsensitiveStrEnum):
            FILE = "file"
            S3 = "s3"

        # Define a new class protocol.
        class DatasetLoaderProtocol(Protocol):
            def load(self) -> Dataset:
                pass

        # Create a new factory for a given class type and protocol.
        class DatasetFactory(AIPerfFactory[DatasetLoaderType, DatasetLoaderProtocol]):
            pass

        # Register a new class type mapping to its corresponding class. It should implement the class protocol.
        @DatasetFactory.register(DatasetLoaderType.FILE)
        class FileDatasetLoader:
            def __init__(self, filename: str):
                self.filename = filename

            def load(self) -> Dataset:
                return Dataset.from_file(self.filename)

        DatasetConfig = {
            "type": DatasetLoaderType.FILE,
            "filename": "data.csv"
        }

        # Create a new instance of the class.
        if DatasetConfig["type"] == DatasetLoaderType.FILE:
            dataset_instance = DatasetFactory.create_instance(DatasetLoaderType.FILE, filename=DatasetConfig["filename"])
        else:
            raise ValueError(f"Unsupported dataset loader type: {DatasetConfig['type']}")

        dataset_instance.load()
    ```
    """

    _logger: AIPerfLogger
    _registry: dict[ClassEnumT | str, type[ClassProtocolT]]
    _override_priorities: dict[ClassEnumT | str, int]

    def __init_subclass__(cls) -> None:
        # Every concrete factory subclass gets its own independent registry,
        # priority table, and logger (not shared with siblings or the base).
        cls._registry = {}
        cls._override_priorities = {}
        cls._logger = AIPerfLogger(cls.__name__)
        super().__init_subclass__()

    @classmethod
    def register_all(
        cls, *class_types: ClassEnumT | str, override_priority: int = 0
    ) -> Callable:
        """Register multiple class types mapping to a single corresponding class.
        This is useful if a single class implements multiple types. Currently only supports
        registering as a single override priority for all types."""

        def decorator(class_cls: type[ClassProtocolT]) -> type[ClassProtocolT]:
            for class_type in class_types:
                cls.register(class_type, override_priority)(class_cls)
            return class_cls

        return decorator

    @classmethod
    def register(
        cls, class_type: ClassEnumT | str, override_priority: int = 0
    ) -> Callable:
        """Register a new class type mapping to its corresponding class.

        Args:
            class_type: The type of class to register
            override_priority: The priority of the override. The higher the priority,
                the more precedence the override has when multiple classes are registered
                for the same class type. Built-in classes have a priority of 0.

        Returns:
            Decorator for the class that implements the class protocol
        """

        def decorator(class_cls: type[ClassProtocolT]) -> type[ClassProtocolT]:
            # -1 means "not yet registered", so any priority >= 0 wins first time.
            existing_priority = cls._override_priorities.get(class_type, -1)
            if class_type in cls._registry and existing_priority >= override_priority:
                # Ties keep the existing registration (first-wins at equal priority).
                cls._logger.warning(
                    f"{class_type!r} class {cls._registry[class_type].__name__} already registered with same or higher priority "
                    f"({existing_priority}). The new registration of class {class_cls.__name__} with priority "
                    f"{override_priority} will be ignored.",
                )
                return class_cls

            if class_type not in cls._registry:
                cls._logger.debug(
                    lambda: f"{class_type!r} class {class_cls.__name__} registered with priority {override_priority}.",
                )
            else:
                cls._logger.warning(
                    f"{class_type!r} class {class_cls.__name__} with priority {override_priority} overrides "
                    f"already registered class {cls._registry[class_type].__name__} with lower priority ({existing_priority}).",
                )
            cls._registry[class_type] = class_cls
            cls._override_priorities[class_type] = override_priority
            return class_cls

        return decorator

    @classmethod
    def create_instance(
        cls,
        class_type: ClassEnumT | str,
        **kwargs: Any,
    ) -> ClassProtocolT:
        """Create a new class instance.

        Args:
            class_type: The type of class to create
            **kwargs: Additional arguments for the class

        Returns:
            The created class instance

        Raises:
            FactoryCreationError: If the class type is not registered or there is an error creating the instance
        """
        if class_type not in cls._registry:
            raise FactoryCreationError(
                f"No implementation registered for {class_type!r} in {cls.__name__}."
            )
        try:
            return cls._registry[class_type](**kwargs)
        except Exception as e:
            # Wrap any constructor failure so callers only need to catch
            # FactoryCreationError; the original exception is chained.
            raise FactoryCreationError(
                f"Error creating {class_type!r} instance for {cls.__name__}: {e}"
            ) from e

    @classmethod
    def get_class_from_type(cls, class_type: ClassEnumT | str) -> type[ClassProtocolT]:
        """Get the class from a class type.

        Args:
            class_type: The class type to get the class from

        Returns:
            The class for the given class type

        Raises:
            TypeError: If the class type is not registered
        """
        if class_type not in cls._registry:
            raise TypeError(
                f"No class found for {class_type!r}. Please register the class first."
            )
        return cls._registry[class_type]

    @classmethod
    def get_all_classes(cls) -> list[type[ClassProtocolT]]:
        """Get all registered classes.

        Returns:
            A list of all registered class types implementing the expected protocol
        """
        return list(cls._registry.values())

    @classmethod
    def get_all_class_types(cls) -> list[ClassEnumT | str]:
        """Get all registered class types."""
        return list(cls._registry.keys())

    @classmethod
    def get_all_classes_and_types(
        cls,
    ) -> list[tuple[type[ClassProtocolT], ClassEnumT | str]]:
        """Get all registered classes and their corresponding class types."""
        # Use `impl_cls` (not `cls`) as the loop variable so it does not shadow
        # the classmethod's `cls` parameter inside the comprehension.
        return [
            (impl_cls, class_type) for class_type, impl_cls in cls._registry.items()
        ]

create_instance(class_type, **kwargs) classmethod

Create a new class instance.

Parameters:

Name Type Description Default
class_type ClassEnumT | str

The type of class to create

required
**kwargs Any

Additional arguments for the class

{}

Returns:

Type Description
ClassProtocolT

The created class instance

Raises:

Type Description
FactoryCreationError

If the class type is not registered or there is an error creating the instance

Source code in aiperf/common/factories.py
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
@classmethod
def create_instance(
    cls,
    class_type: ClassEnumT | str,
    **kwargs: Any,
) -> ClassProtocolT:
    """Create a new class instance.

    Args:
        class_type: The type of class to create
        **kwargs: Additional arguments for the class

    Returns:
        The created class instance

    Raises:
        FactoryCreationError: If the class type is not registered or there is an error creating the instance
    """
    if class_type not in cls._registry:
        raise FactoryCreationError(
            f"No implementation registered for {class_type!r} in {cls.__name__}."
        )
    try:
        return cls._registry[class_type](**kwargs)
    except Exception as e:
        raise FactoryCreationError(
            f"Error creating {class_type!r} instance for {cls.__name__}: {e}"
        ) from e

get_all_class_types() classmethod

Get all registered class types.

Source code in aiperf/common/factories.py
233
234
235
236
@classmethod
def get_all_class_types(cls) -> list[ClassEnumT | str]:
    """Get all registered class types."""
    return list(cls._registry.keys())

get_all_classes() classmethod

Get all registered classes.

Returns:

Type Description
list[type[ClassProtocolT]]

A list of all registered class types implementing the expected protocol

Source code in aiperf/common/factories.py
224
225
226
227
228
229
230
231
@classmethod
def get_all_classes(cls) -> list[type[ClassProtocolT]]:
    """Get all registered classes.

    Returns:
        A list of all registered class types implementing the expected protocol
    """
    return list(cls._registry.values())

get_all_classes_and_types() classmethod

Get all registered classes and their corresponding class types.

Source code in aiperf/common/factories.py
238
239
240
241
242
243
@classmethod
def get_all_classes_and_types(
    cls,
) -> list[tuple[type[ClassProtocolT], ClassEnumT | str]]:
    """Get all registered classes and their corresponding class types."""
    return [(cls, class_type) for class_type, cls in cls._registry.items()]

get_class_from_type(class_type) classmethod

Get the class from a class type.

Parameters:

Name Type Description Default
class_type ClassEnumT | str

The class type to get the class from

required

Returns:

Type Description
type[ClassProtocolT]

The class for the given class type

Raises:

Type Description
TypeError

If the class type is not registered

Source code in aiperf/common/factories.py
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
@classmethod
def get_class_from_type(cls, class_type: ClassEnumT | str) -> type[ClassProtocolT]:
    """Get the class from a class type.

    Args:
        class_type: The class type to get the class from

    Returns:
        The class for the given class type

    Raises:
        TypeError: If the class type is not registered
    """
    if class_type not in cls._registry:
        raise TypeError(
            f"No class found for {class_type!r}. Please register the class first."
        )
    return cls._registry[class_type]

register(class_type, override_priority=0) classmethod

Register a new class type mapping to its corresponding class.

Parameters:

Name Type Description Default
class_type ClassEnumT | str

The type of class to register

required
override_priority int

The priority of the override. The higher the priority, the more precedence the override has when multiple classes are registered for the same class type. Built-in classes have a priority of 0.

0

Returns:

Type Description
Callable

Decorator for the class that implements the class protocol

Source code in aiperf/common/factories.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
@classmethod
def register(
    cls, class_type: ClassEnumT | str, override_priority: int = 0
) -> Callable:
    """Register a new class type mapping to its corresponding class.

    Args:
        class_type: The type of class to register
        override_priority: The priority of the override. The higher the priority,
            the more precedence the override has when multiple classes are registered
            for the same class type. Built-in classes have a priority of 0.

    Returns:
        Decorator for the class that implements the class protocol
    """

    def decorator(class_cls: type[ClassProtocolT]) -> type[ClassProtocolT]:
        existing_priority = cls._override_priorities.get(class_type, -1)
        if class_type in cls._registry and existing_priority >= override_priority:
            cls._logger.warning(
                f"{class_type!r} class {cls._registry[class_type].__name__} already registered with same or higher priority "
                f"({existing_priority}). The new registration of class {class_cls.__name__} with priority "
                f"{override_priority} will be ignored.",
            )
            return class_cls

        if class_type not in cls._registry:
            cls._logger.debug(
                lambda: f"{class_type!r} class {class_cls.__name__} registered with priority {override_priority}.",
            )
        else:
            cls._logger.warning(
                f"{class_type!r} class {class_cls.__name__} with priority {override_priority} overrides "
                f"already registered class {cls._registry[class_type].__name__} with lower priority ({existing_priority}).",
            )
        cls._registry[class_type] = class_cls
        cls._override_priorities[class_type] = override_priority
        return class_cls

    return decorator

register_all(*class_types, override_priority=0) classmethod

Register multiple class types mapping to a single corresponding class. This is useful if a single class implements multiple types. Currently only supports registering as a single override priority for all types.

Source code in aiperf/common/factories.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
@classmethod
def register_all(
    cls, *class_types: ClassEnumT | str, override_priority: int = 0
) -> Callable:
    """Register multiple class types mapping to a single corresponding class.
    This is useful if a single class implements multiple types. Currently only supports
    registering as a single override priority for all types."""

    def decorator(class_cls: type[ClassProtocolT]) -> type[ClassProtocolT]:
        for class_type in class_types:
            cls.register(class_type, override_priority)(class_cls)
        return class_cls

    return decorator

AIPerfSingletonFactory

Bases: AIPerfFactory[ClassEnumT, ClassProtocolT]

Factory for registering and creating singleton instances of a given class type and protocol. This factory is useful for creating instances that are shared across the application, such as communication clients. Calling create_instance will create a new instance if it doesn't exist, otherwise it will return the existing instance. Calling get_instance will return the existing instance if it exists, otherwise it will raise an error. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
class AIPerfSingletonFactory(AIPerfFactory[ClassEnumT, ClassProtocolT]):
    """Factory for registering and creating singleton instances of a given class type and protocol.
    This factory is useful for creating instances that are shared across the application, such as communication clients.
    Calling create_instance will create a new instance if it doesn't exist, otherwise it will return the existing instance.
    Calling get_instance will return the existing instance if it exists, otherwise it will raise an error.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    _instances: dict[ClassEnumT | str, ClassProtocolT]
    _instances_lock: Lock
    _instances_pid: dict[ClassEnumT | str, int]

    def __init_subclass__(cls) -> None:
        # Each concrete singleton factory gets its own instance cache, lock,
        # and PID table (not shared across factory subclasses).
        cls._instances = {}
        cls._instances_lock = Lock()
        cls._instances_pid = {}
        super().__init_subclass__()

    @classmethod
    def set_instance(
        cls, class_type: ClassEnumT | str, instance: ClassProtocolT
    ) -> None:
        """Manually install `instance` as the singleton for `class_type`."""
        cls._instances[class_type] = instance
        # BUGFIX: record the owning PID as well. Previously this was left
        # unset, so a later get_instance/create_instance for the same
        # class_type raised KeyError on the _instances_pid lookup.
        cls._instances_pid[class_type] = os.getpid()

    @classmethod
    def get_or_create_instance(
        cls, class_type: ClassEnumT | str, **kwargs: Any
    ) -> ClassProtocolT:
        """Syntactic sugar for create_instance, but with a more descriptive name for singleton factories."""
        return cls.create_instance(class_type, **kwargs)

    @classmethod
    def create_instance(
        cls, class_type: ClassEnumT | str, **kwargs: Any
    ) -> ClassProtocolT:
        """Create a new instance of the given class type.
        If the instance does not exist, or the process ID has changed, a new instance will be created.
        """
        # TODO: Technically, this should handle the case where kwargs are different,
        #       but that would require a more complex implementation.
        if (
            class_type not in cls._instances
            or os.getpid() != cls._instances_pid[class_type]
        ):
            cls._logger.debug(
                lambda: f"Creating new instance for {class_type!r} in {cls.__name__}."
            )
            # Double-checked locking: re-test under the lock so two concurrent
            # callers do not both create an instance.
            with cls._instances_lock:
                if (
                    class_type not in cls._instances
                    or os.getpid() != cls._instances_pid[class_type]
                ):
                    cls._instances[class_type] = super().create_instance(
                        class_type, **kwargs
                    )
                    cls._instances_pid[class_type] = os.getpid()
                    cls._logger.debug(
                        lambda: f"New instance for {class_type!r} in {cls.__name__} created."
                    )
        else:
            cls._logger.debug(
                lambda: f"Instance for {class_type!r} in {cls.__name__} already exists. Returning existing instance."
            )
        return cls._instances[class_type]

    @classmethod
    def get_instance(cls, class_type: ClassEnumT | str) -> ClassProtocolT:
        """Return the existing singleton for `class_type`.

        Raises:
            InvalidStateError: If no instance exists, or it was created in a
                different process (e.g. before a fork).
        """
        if class_type not in cls._instances:
            raise InvalidStateError(
                f"No instance found for {class_type!r} in {cls.__name__}. "
                f"Ensure you call AIPerfSingletonFactory.create_instance({class_type!r}) first."
            )
        if os.getpid() != cls._instances_pid[class_type]:
            raise InvalidStateError(
                f"Instance for {class_type!r} in {cls.__name__} is not valid for the current process. "
                f"Ensure you call AIPerfSingletonFactory.create_instance({class_type!r}) first after forking."
            )
        return cls._instances[class_type]

    @classmethod
    def get_all_instances(cls) -> dict[ClassEnumT | str, ClassProtocolT]:
        """Return the live mapping of class types to their singleton instances."""
        return cls._instances

    @classmethod
    def has_instance(cls, class_type: ClassEnumT | str) -> bool:
        """Return True if a singleton already exists for `class_type` (PID is not checked)."""
        return class_type in cls._instances

create_instance(class_type, **kwargs) classmethod

Create a new instance of the given class type. If the instance does not exist, or the process ID has changed, a new instance will be created.

Source code in aiperf/common/factories.py
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
@classmethod
def create_instance(
    cls, class_type: ClassEnumT | str, **kwargs: Any
) -> ClassProtocolT:
    """Create a new instance of the given class type.
    If the instance does not exist, or the process ID has changed, a new instance will be created.
    """
    # TODO: Technically, this should handle the case where kwargs are different,
    #       but that would require a more complex implementation.
    if (
        class_type not in cls._instances
        or os.getpid() != cls._instances_pid[class_type]
    ):
        cls._logger.debug(
            lambda: f"Creating new instance for {class_type!r} in {cls.__name__}."
        )
        with cls._instances_lock:
            if (
                class_type not in cls._instances
                or os.getpid() != cls._instances_pid[class_type]
            ):
                cls._instances[class_type] = super().create_instance(
                    class_type, **kwargs
                )
                cls._instances_pid[class_type] = os.getpid()
                cls._logger.debug(
                    lambda: f"New instance for {class_type!r} in {cls.__name__} created."
                )
    else:
        cls._logger.debug(
            lambda: f"Instance for {class_type!r} in {cls.__name__} already exists. Returning existing instance."
        )
    return cls._instances[class_type]

get_or_create_instance(class_type, **kwargs) classmethod

Syntactic sugar for create_instance, but with a more descriptive name for singleton factories.

Source code in aiperf/common/factories.py
270
271
272
273
274
275
@classmethod
def get_or_create_instance(
    cls, class_type: ClassEnumT | str, **kwargs: Any
) -> ClassProtocolT:
    """Syntactic sugar for create_instance, but with a more descriptive name for singleton factories."""
    return cls.create_instance(class_type, **kwargs)

CommunicationClientFactory

Bases: AIPerfFactory[CommClientType, 'CommunicationClientProtocol']

Factory for registering and creating CommunicationClientProtocol instances based on the specified communication client type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
class CommunicationClientFactory(
    AIPerfFactory[CommClientType, "CommunicationClientProtocol"]
):
    """Factory for registering and creating CommunicationClientProtocol instances based on the specified communication client type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def create_instance(  # type: ignore[override]
        cls,
        class_type: CommClientType | str,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
        **kwargs,
    ) -> "CommunicationClientProtocol":
        """Create a client for `class_type`, forwarding the connection
        parameters (address, bind, socket_ops) as keyword arguments to the
        registered class constructor."""
        return super().create_instance(
            class_type, address=address, bind=bind, socket_ops=socket_ops, **kwargs
        )

CommunicationFactory

Bases: AIPerfSingletonFactory[CommunicationBackend, 'CommunicationProtocol']

Factory for registering and creating CommunicationProtocol instances based on the specified communication backend. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
class CommunicationFactory(
    AIPerfSingletonFactory[CommunicationBackend, "CommunicationProtocol"]
):
    """Factory for registering and creating CommunicationProtocol instances based on the specified communication backend.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def create_instance(  # type: ignore[override]
        cls,
        class_type: CommunicationBackend | str,
        config: "BaseZMQCommunicationConfig",
        **kwargs,
    ) -> "CommunicationProtocol":
        """Create (or, being a singleton factory, return the existing) instance
        for `class_type`, forwarding `config` to the registered class."""
        return super().create_instance(class_type, config=config, **kwargs)

ComposerFactory

Bases: AIPerfFactory[ComposerType, 'BaseDatasetComposer']

Factory for registering and creating BaseDatasetComposer instances based on the specified composer type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
372
373
374
375
376
377
378
379
380
381
382
383
class ComposerFactory(AIPerfFactory[ComposerType, "BaseDatasetComposer"]):
    """Factory for registering and creating BaseDatasetComposer instances based on the specified composer type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def create_instance(  # type: ignore[override]
        cls,
        class_type: ComposerType | str,
        **kwargs,
    ) -> "BaseDatasetComposer":
        # Adds no arguments; exists solely to narrow the return type for
        # static type checkers.
        return super().create_instance(class_type, **kwargs)

CustomDatasetFactory

Bases: AIPerfFactory[CustomDatasetType, 'CustomDatasetLoaderProtocol']

Factory for registering and creating CustomDatasetLoaderProtocol instances based on the specified custom dataset type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
386
387
388
389
390
391
392
393
394
395
396
397
398
399
class CustomDatasetFactory(
    AIPerfFactory[CustomDatasetType, "CustomDatasetLoaderProtocol"]
):
    """Factory for registering and creating CustomDatasetLoaderProtocol instances based on the specified custom dataset type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def create_instance(  # type: ignore[override]
        cls,
        class_type: CustomDatasetType | str,
        **kwargs,
    ) -> "CustomDatasetLoaderProtocol":
        # Adds no arguments; exists solely to narrow the return type for
        # static type checkers.
        return super().create_instance(class_type, **kwargs)

DataExporterFactory

Bases: AIPerfFactory[DataExporterType, 'DataExporterProtocol']

Factory for registering and creating DataExporterProtocol instances based on the specified data exporter type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
class DataExporterFactory(AIPerfFactory[DataExporterType, "DataExporterProtocol"]):
    """Factory for registering and creating DataExporterProtocol instances based on the specified data exporter type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def create_instance(  # type: ignore[override]
        cls,
        class_type: DataExporterType | str,
        exporter_config: "ExporterConfig",
        **kwargs,
    ) -> "DataExporterProtocol":
        # Narrows the base signature: `exporter_config` is required and is
        # forwarded by keyword to the base factory.
        return super().create_instance(
            class_type, exporter_config=exporter_config, **kwargs
        )

InferenceClientFactory

Bases: AIPerfFactory[EndpointType, 'InferenceClientProtocol']

Factory for registering and creating InferenceClientProtocol instances based on the specified endpoint type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
class InferenceClientFactory(AIPerfFactory[EndpointType, "InferenceClientProtocol"]):
    """Factory for registering and creating InferenceClientProtocol instances based on the specified endpoint type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def create_instance(  # type: ignore[override]
        cls,
        class_type: EndpointType | str,
        model_endpoint: "ModelEndpointInfo",
        **kwargs,
    ) -> "InferenceClientProtocol":
        # Narrows the base signature: `model_endpoint` is required and is
        # forwarded by keyword to the base factory.
        return super().create_instance(
            class_type, model_endpoint=model_endpoint, **kwargs
        )

RecordProcessorFactory

Bases: AIPerfFactory[RecordProcessorType, 'RecordProcessorProtocol']

Factory for registering and creating RecordProcessorProtocol instances based on the specified record processor type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
class RecordProcessorFactory(
    AIPerfFactory[RecordProcessorType, "RecordProcessorProtocol"]
):
    """Factory for registering and creating RecordProcessorProtocol instances based on the specified record processor type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def create_instance(  # type: ignore[override]
        cls,
        class_type: RecordProcessorType | str,
        service_config: "ServiceConfig",
        user_config: "UserConfig",
        **kwargs,
    ) -> "RecordProcessorProtocol":
        # Narrows the base signature: both configs are required and forwarded
        # by keyword to the base factory.
        return super().create_instance(
            class_type,
            service_config=service_config,
            user_config=user_config,
            **kwargs,
        )

RequestConverterFactory

Bases: AIPerfSingletonFactory[EndpointType, 'RequestConverterProtocol']

Factory for registering and creating RequestConverterProtocol instances based on the specified request payload type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
436
437
438
439
440
441
class RequestConverterFactory(
    AIPerfSingletonFactory[EndpointType, "RequestConverterProtocol"]
):
    """Factory for registering and creating RequestConverterProtocol instances based on the specified request payload type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    # No create_instance override here: the base AIPerfSingletonFactory
    # signature is used as-is.

ResponseExtractorFactory

Bases: AIPerfFactory[EndpointType, 'ResponseExtractorProtocol']

Factory for registering and creating ResponseExtractorProtocol instances based on the specified response extractor type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
class ResponseExtractorFactory(
    AIPerfFactory[EndpointType, "ResponseExtractorProtocol"]
):
    """Factory for registering and creating ResponseExtractorProtocol instances based on the specified response extractor type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def create_instance(  # type: ignore[override]
        cls,
        class_type: EndpointType | str,
        model_endpoint: "ModelEndpointInfo",
        **kwargs,
    ) -> "ResponseExtractorProtocol":
        # Narrows the base signature: `model_endpoint` is required and is
        # forwarded by keyword to the base factory.
        return super().create_instance(
            class_type, model_endpoint=model_endpoint, **kwargs
        )

ResultsProcessorFactory

Bases: AIPerfFactory[ResultsProcessorType, 'ResultsProcessorProtocol']

Factory for registering and creating ResultsProcessorProtocol instances based on the specified results processor type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
class ResultsProcessorFactory(
    AIPerfFactory[ResultsProcessorType, "ResultsProcessorProtocol"]
):
    """Factory for registering and creating ResultsProcessorProtocol instances based on the specified results processor type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def create_instance(  # type: ignore[override]
        cls,
        class_type: ResultsProcessorType | str,
        service_config: "ServiceConfig",
        user_config: "UserConfig",
        **kwargs,
    ) -> "ResultsProcessorProtocol":
        # Narrows the base signature: both configs are required and forwarded
        # by keyword to the base factory.
        return super().create_instance(
            class_type,
            service_config=service_config,
            user_config=user_config,
            **kwargs,
        )

ServiceFactory

Bases: AIPerfFactory[ServiceType, 'ServiceProtocol']

Factory for registering and creating ServiceProtocol instances based on the specified service type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
class ServiceFactory(AIPerfFactory[ServiceType, "ServiceProtocol"]):
    """Factory for registering and creating ServiceProtocol instances based on the specified service type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def register_all(
        cls, *class_types: ServiceTypeT, override_priority: int = 0
    ) -> Callable[..., Any]:
        # Bulk registration is deliberately disabled for services: a single
        # service class may only be bound to a single ServiceType.
        raise InvalidOperationError(
            "ServiceFactory.register_all is not supported. A single service can only be registered with a single type."
        )

    @classmethod
    def register(
        cls, class_type: ServiceTypeT, override_priority: int = 0
    ) -> Callable[..., Any]:
        # Override the register method to set the service_type on the class
        original_decorator = super().register(class_type, override_priority)

        def decorator(class_cls: type[ServiceProtocolT]) -> type[ServiceProtocolT]:
            # Stamp the registered type onto the class before delegating to the
            # base registration decorator, then return the class unchanged.
            class_cls.service_type = class_type
            original_decorator(class_cls)
            return class_cls

        return decorator

ServiceManagerFactory

Bases: AIPerfFactory[ServiceRunType, 'ServiceManagerProtocol']

Factory for registering and creating ServiceManagerProtocol instances based on the specified service run type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
class ServiceManagerFactory(AIPerfFactory[ServiceRunType, "ServiceManagerProtocol"]):
    """Factory for registering and creating ServiceManagerProtocol instances based on the specified service run type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def create_instance(  # type: ignore[override]
        cls,
        class_type: ServiceRunType | str,
        required_services: dict[ServiceTypeT, int],
        service_config: "ServiceConfig",
        user_config: "UserConfig",
        **kwargs,
    ) -> "ServiceManagerProtocol":
        # Narrows the base signature: the required-services map (service type
        # -> instance count) and both configs are forwarded by keyword.
        return super().create_instance(
            class_type,
            required_services=required_services,
            service_config=service_config,
            user_config=user_config,
            **kwargs,
        )

ZMQProxyFactory

Bases: AIPerfFactory[ZMQProxyType, 'BaseZMQProxy']

Factory for registering and creating BaseZMQProxy instances based on the specified ZMQ proxy type. see: :class:aiperf.common.factories.AIPerfFactory for more details.

Source code in aiperf/common/factories.py
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
class ZMQProxyFactory(AIPerfFactory[ZMQProxyType, "BaseZMQProxy"]):
    """Factory for registering and creating BaseZMQProxy instances based on the specified ZMQ proxy type.
    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
    """

    @classmethod
    def create_instance(  # type: ignore[override]
        cls,
        class_type: ZMQProxyType | str,
        zmq_proxy_config: "BaseZMQProxyConfig",
        **kwargs,
    ) -> "BaseZMQProxy":
        # Narrows the base signature: `zmq_proxy_config` is required and is
        # forwarded by keyword to the base factory.
        return super().create_instance(
            class_type, zmq_proxy_config=zmq_proxy_config, **kwargs
        )

aiperf.common.hooks

This module provides an extensive set of hook definitions for AIPerf. It is designed to be used in conjunction with the :class:HooksMixin for classes to provide support for hooks. It provides a simple interface for registering hooks.

Classes should inherit from the :class:HooksMixin, and specify the provided hook types by decorating the class with the :func:provides_hooks decorator.

The hook functions are registered by decorating functions with the various hook decorators such as :func:on_init, :func:on_start, :func:on_stop, etc.

More than one hook can be registered for a given hook type, and classes that inherit from classes with existing hooks will inherit the hooks from the base classes as well.

The hooks are run by calling the :meth:HooksMixin.run_hooks method or retrieved via the :meth:HooksMixin.get_hooks method on the class.

HookType = AIPerfHook | str module-attribute

Type alias for valid hook types. This is a union of the AIPerfHook enum and any user-defined custom strings.

Hook

Bases: BaseModel, Generic[HookParamsT]

A hook is a function that is decorated with a hook type and optional parameters. The HookParamsT is the type of the parameters. You can either have a static value, or a callable that returns the parameters.

Source code in aiperf/common/hooks.py
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
class Hook(BaseModel, Generic[HookParamsT]):
    """A hook is a function that is decorated with a hook type and optional parameters.
    The HookParamsT is the type of the parameters. You can either have a static value,
    or a callable that returns the parameters.
    """

    # The decorated function; its hook metadata lives in attributes set by the
    # hook decorators (see HookAttrs).
    func: Callable
    params: HookParamsT | Callable[[SelfT], HookParamsT] | None = None  # type: ignore

    @property
    def hook_type(self) -> HookType:
        """The hook type stored on the decorated function by its decorator."""
        return getattr(self.func, HookAttrs.HOOK_TYPE)

    @property
    def func_name(self) -> str:
        """Bare function name (no class/module qualification)."""
        return self.func.__name__

    @property
    def qualified_name(self) -> str:
        """Dotted qualified name, e.g. ``MyClass._my_hook``."""
        return f"{self.func.__qualname__}"

    def resolve_params(self, self_obj: SelfT) -> HookParamsT | None:
        """Resolve the parameters for the hook. If the parameters are a callable, it will be called
        with the self_obj as the argument, otherwise the parameters are returned as is."""
        if self.params is None:
            return None
        # With variable length parameters, you get a tuple with 1 item in it, so we need to check for that.
        if (
            isinstance(self.params, Iterable)
            and len(self.params) == 1
            and callable(self.params[0])
        ):  # type: ignore
            return self.params[0](self_obj)  # type: ignore
        if callable(self.params):
            return self.params(self_obj)
        return self.params  # type: ignore

    async def __call__(self, **kwargs) -> None:
        """Invoke the hook: coroutine functions are awaited directly; plain
        functions are run in a worker thread so they do not block the loop."""
        if asyncio.iscoroutinefunction(self.func):
            await self.func(**kwargs)
        else:
            await asyncio.to_thread(self.func, **kwargs)

    def __str__(self) -> str:
        return f"{self.hook_type} 🡒 {self.qualified_name}"

resolve_params(self_obj)

Resolve the parameters for the hook. If the parameters are a callable, it will be called with the self_obj as the argument, otherwise the parameters are returned as is.

Source code in aiperf/common/hooks.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def resolve_params(self, self_obj: SelfT) -> HookParamsT | None:
    """Resolve the parameters for the hook. If the parameters are a callable, it will be called
    with the self_obj as the argument, otherwise the parameters are returned as is."""
    if self.params is None:
        return None
    # With variable length parameters, you get a tuple with 1 item in it, so we need to check for that.
    # e.g. a decorator taking *params that was handed one callable stores
    # ("callable",) — unwrap before resolving.
    if (
        isinstance(self.params, Iterable)
        and len(self.params) == 1
        and callable(self.params[0])
    ):  # type: ignore
        return self.params[0](self_obj)  # type: ignore
    if callable(self.params):
        return self.params(self_obj)
    return self.params  # type: ignore

HookAttrs

Constant attribute names for hooks.

When you decorate a function with a hook decorator, the hook type and parameters are set as attributes on the function or class.

Source code in aiperf/common/hooks.py
57
58
59
60
61
62
63
64
65
66
class HookAttrs:
    """Constant attribute names for hooks.

    When you decorate a function with a hook decorator, the hook type and parameters are
    set as attributes on the function or class.
    """

    # Attribute on a decorated function holding its hook type.
    HOOK_TYPE = "__aiperf_hook_type__"
    # Attribute on a decorated function holding the decorator's parameters.
    HOOK_PARAMS = "__aiperf_hook_params__"
    # Attribute on a class holding the set of hook types it provides.
    PROVIDES_HOOKS = "__provides_hooks__"

background_task(interval=None, immediate=True, stop_on_error=False)

Decorator to mark a method as a background task with automatic management.

Tasks are automatically started when the service starts and stopped when the service stops. The decorated method will be run periodically in the background when the service is running.

Parameters:

Name Type Description Default
interval float | Callable[[SelfT], float] | None

Time between task executions in seconds. If None, the task will run once. Can be a callable that returns the interval, and will be called with 'self' as the argument.

None
immediate bool

If True, run the task immediately on start, otherwise wait for the interval first.

True
stop_on_error bool

If True, stop the task on any exception, otherwise log and continue.

False

Example:

class MyPlugin(AIPerfLifecycleMixin):
    @background_task(interval=1.0)
    def _background_task(self) -> None:
        pass

The above is the equivalent to setting:

MyPlugin._background_task.__aiperf_hook_type__ = AIPerfHook.BACKGROUND_TASK
MyPlugin._background_task.__aiperf_hook_params__ = BackgroundTaskParams(
    interval=1.0, immediate=True, stop_on_error=False
)
Source code in aiperf/common/hooks.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
def background_task(
    interval: float | Callable[[SelfT], float] | None = None,
    immediate: bool = True,
    stop_on_error: bool = False,
) -> Callable:
    """
    Decorator to mark a method as a background task with automatic management.

    Tasks are automatically started when the service starts and stopped when the service stops.
    The decorated method will be run periodically in the background when the service is running.

    Args:
        interval: Time between task executions in seconds. If None, the task will run once.
            Can be a callable that returns the interval, and will be called with 'self' as the argument.
        immediate: If True, run the task immediately on start, otherwise wait for the interval first.
        stop_on_error: If True, stop the task on any exception, otherwise log and continue.

    Example:
    ```python
    class MyPlugin(AIPerfLifecycleMixin):
        @background_task(interval=1.0)
        def _background_task(self) -> None:
            pass
    ```

    The above is the equivalent to setting:
    ```python
    MyPlugin._background_task.__aiperf_hook_type__ = AIPerfHook.BACKGROUND_TASK
    MyPlugin._background_task.__aiperf_hook_params__ = BackgroundTaskParams(
        interval=1.0, immediate=True, stop_on_error=False
    )
    ```
    """
    # Bundle the settings into BackgroundTaskParams and attach them (plus the
    # BACKGROUND_TASK hook type) to the decorated method.
    return _hook_decorator_with_params(
        AIPerfHook.BACKGROUND_TASK,
        BackgroundTaskParams(
            interval=interval, immediate=immediate, stop_on_error=stop_on_error
        ),
    )

on_command(*command_types)

Decorator to specify that the function is a hook that should be called when a CommandMessage with the given command type(s) is received from the message bus. See :func:aiperf.common.hooks._hook_decorator_for_message_types.

Example:

class MyService(BaseComponentService):
    @on_command(CommandType.PROFILE_START)
    def _on_profile_start(self, message: ProfileStartCommand) -> CommandResponse:
        pass

The above is the equivalent to setting:

MyService._on_profile_start.__aiperf_hook_type__ = AIPerfHook.ON_COMMAND
MyService._on_profile_start.__aiperf_hook_params__ = (CommandType.PROFILE_START,)
Source code in aiperf/common/hooks.py
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
def on_command(
    *command_types: CommandTypeT | Callable[[SelfT], Iterable[CommandTypeT]],
) -> Callable:
    """Decorator to specify that the function is a hook that should be called when a CommandMessage with the given
    command type(s) is received from the message bus.
    See :func:`aiperf.common.hooks._hook_decorator_for_message_types`.

    Example:
    ```python
    class MyService(BaseComponentService):
        @on_command(CommandType.PROFILE_START)
        def _on_profile_start(self, message: ProfileStartCommand) -> CommandResponse:
            pass
    ```

    The above is the equivalent to setting:
    ```python
    MyService._on_profile_start.__aiperf_hook_type__ = AIPerfHook.ON_COMMAND
    MyService._on_profile_start.__aiperf_hook_params__ = (CommandType.PROFILE_START,)
    ```
    """
    # Attach the ON_COMMAND hook type and the accepted command types.
    return _hook_decorator_with_params(AIPerfHook.ON_COMMAND, command_types)

on_init(func)

Decorator to specify that the function is a hook that should be called during initialization. See :func:aiperf.common.hooks._hook_decorator.

Example:

class MyPlugin(AIPerfLifecycleMixin):
    @on_init
    def _init_plugin(self) -> None:
        pass

The above is the equivalent to setting:

MyPlugin._init_plugin.__aiperf_hook_type__ = AIPerfHook.ON_INIT
Source code in aiperf/common/hooks.py
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
def on_init(func: Callable) -> Callable:
    """Decorator to specify that the function is a hook that should be called during initialization.
    See :func:`aiperf.common.hooks._hook_decorator`.

    Example:
    ```python
    class MyPlugin(AIPerfLifecycleMixin):
        @on_init
        def _init_plugin(self) -> None:
            pass
    ```

    The above is the equivalent to setting:
    ```python
    MyPlugin._init_plugin.__aiperf_hook_type__ = AIPerfHook.ON_INIT
    ```
    """
    # Parameter-less hook: only the ON_INIT hook type is attached.
    return _hook_decorator(AIPerfHook.ON_INIT, func)

on_message(*message_types)

Decorator to specify that the function is a hook that should be called when messages of the given type(s) (or topics) are received from the message bus. See :func:aiperf.common.hooks._hook_decorator_with_params.

Example:

class MyService(MessageBusClientMixin):
    @on_message(MessageType.STATUS)
    def _on_status_message(self, message: StatusMessage) -> None:
        pass

The above is the equivalent to setting:

MyService._on_status_message.__aiperf_hook_type__ = AIPerfHook.ON_MESSAGE
MyService._on_status_message.__aiperf_hook_params__ = (MessageType.STATUS,)
Source code in aiperf/common/hooks.py
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
def on_message(
    *message_types: MessageTypeT | Callable[[SelfT], Iterable[MessageTypeT]],
) -> Callable:
    """Decorator to specify that the function is a hook that should be called when messages of the
    given type(s) (or topics) are received from the message bus.
    See :func:`aiperf.common.hooks._hook_decorator_with_params`.

    Example:
    ```python
    class MyService(MessageBusClientMixin):
        @on_message(MessageType.STATUS)
        def _on_status_message(self, message: StatusMessage) -> None:
            pass
    ```

    The above is the equivalent to setting:
    ```python
    MyService._on_status_message.__aiperf_hook_type__ = AIPerfHook.ON_MESSAGE
    MyService._on_status_message.__aiperf_hook_params__ = (MessageType.STATUS,)
    ```
    """
    # Attach the ON_MESSAGE hook type and the subscribed message types.
    return _hook_decorator_with_params(AIPerfHook.ON_MESSAGE, message_types)

on_pull_message(*message_types)

Decorator to specify that the function is a hook that should be called when a pull client receives a message of the given type(s). See :func:aiperf.common.hooks._hook_decorator_for_message_types.

Example:

class MyService(PullClientMixin, BaseComponentService):
    @on_pull_message(MessageType.CREDIT_DROP)
    def _on_credit_drop_pull(self, message: CreditDropMessage) -> None:
        pass

The above is the equivalent to setting:

MyService._on_pull_message.__aiperf_hook_type__ = AIPerfHook.ON_PULL_MESSAGE
MyService._on_pull_message.__aiperf_hook_params__ = (MessageType.CREDIT_DROP,)

Source code in aiperf/common/hooks.py
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
def on_pull_message(
    *message_types: MessageTypeT | Callable[[SelfT], Iterable[MessageTypeT]],
) -> Callable:
    """Decorator to specify that the function is a hook that should be called when a pull client
    receives a message of the given type(s).
    See :func:`aiperf.common.hooks._hook_decorator_for_message_types`.

    Example:
    ```python
    class MyService(PullClientMixin, BaseComponentService):
        @on_pull_message(MessageType.CREDIT_DROP)
        def _on_credit_drop_pull(self, message: CreditDropMessage) -> None:
            pass
    ```

    The above is the equivalent to setting:
    ```python
    MyService._on_pull_message.__aiperf_hook_type__ = AIPerfHook.ON_PULL_MESSAGE
    MyService._on_pull_message.__aiperf_hook_params__ = (MessageType.CREDIT_DROP,)
    ```
    """
    return _hook_decorator_with_params(AIPerfHook.ON_PULL_MESSAGE, message_types)

on_request(*message_types)

Decorator to specify that the function is a hook that should be called when requests of the given type(s) are received from a ReplyClient. See :func:aiperf.common.hooks._hook_decorator_for_message_types.

Example:

class MyService(RequestClientMixin, BaseComponentService):
    @on_request(MessageType.CONVERSATION_REQUEST)
    async def _handle_conversation_request(
        self, message: ConversationRequestMessage
    ) -> ConversationResponseMessage:
        return ConversationResponseMessage(
            ...
        )

The above is the equivalent to setting:

MyService._handle_conversation_request.__aiperf_hook_type__ = AIPerfHook.ON_REQUEST
MyService._handle_conversation_request.__aiperf_hook_params__ = (MessageType.CONVERSATION_REQUEST,)
Source code in aiperf/common/hooks.py
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
def on_request(
    *message_types: MessageTypeT | Callable[[SelfT], Iterable[MessageTypeT]],
) -> Callable:
    """Decorator to specify that the function is a hook that should be called when requests of the
    given type(s) are received from a ReplyClient.
    See :func:`aiperf.common.hooks._hook_decorator_for_message_types`.

    Example:
    ```python
    class MyService(RequestClientMixin, BaseComponentService):
        @on_request(MessageType.CONVERSATION_REQUEST)
        async def _handle_conversation_request(
            self, message: ConversationRequestMessage
        ) -> ConversationResponseMessage:
            return ConversationResponseMessage(
                ...
            )
    ```

    The above is the equivalent to setting:
    ```python
    MyService._handle_conversation_request.__aiperf_hook_type__ = AIPerfHook.ON_REQUEST
    MyService._handle_conversation_request.__aiperf_hook_params__ = (MessageType.CONVERSATION_REQUEST,)
    ```
    """
    # Attach the ON_REQUEST hook type and the handled request message types.
    return _hook_decorator_with_params(AIPerfHook.ON_REQUEST, message_types)

on_start(func)

Decorator to specify that the function is a hook that should be called during start. See :func:aiperf.common.hooks._hook_decorator.

Example:

class MyPlugin(AIPerfLifecycleMixin):
    @on_start
    def _start_plugin(self) -> None:
        pass

The above is the equivalent to setting:

MyPlugin._start_plugin.__aiperf_hook_type__ = AIPerfHook.ON_START
Source code in aiperf/common/hooks.py
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def on_start(func: Callable) -> Callable:
    """Decorator to specify that the function is a hook that should be called during start.
    See :func:`aiperf.common.hooks._hook_decorator`.

    Example:
    ```python
    class MyPlugin(AIPerfLifecycleMixin):
        @on_start
        def _start_plugin(self) -> None:
            pass
    ```

    The above is the equivalent to setting:
    ```python
    MyPlugin._start_plugin.__aiperf_hook_type__ = AIPerfHook.ON_START
    ```
    """
    # Parameter-less hook: only the ON_START hook type is attached.
    return _hook_decorator(AIPerfHook.ON_START, func)

on_state_change(func)

Decorator to specify that the function is a hook that should be called during the service state change. See :func:aiperf.common.hooks._hook_decorator.

Example:

class MyPlugin(AIPerfLifecycleMixin):
    @on_state_change
    def _on_state_change(self, old_state: LifecycleState, new_state: LifecycleState) -> None:
        pass

The above is the equivalent to setting:

MyPlugin._on_state_change.__aiperf_hook_type__ = AIPerfHook.ON_STATE_CHANGE
Source code in aiperf/common/hooks.py
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
def on_state_change(
    func: Callable[["HooksMixinT", LifecycleState, LifecycleState], Awaitable],
) -> Callable[["HooksMixinT", LifecycleState, LifecycleState], Awaitable]:
    """Decorator to specify that the function is a hook that should be called during the service state change.
    See :func:`aiperf.common.hooks._hook_decorator`.

    Example:
    ```python
    class MyPlugin(AIPerfLifecycleMixin):
        @on_state_change
        def _on_state_change(self, old_state: LifecycleState, new_state: LifecycleState) -> None:
            pass
    ```

    The above is the equivalent to setting:
    ```python
    MyPlugin._on_state_change.__aiperf_hook_type__ = AIPerfHook.ON_STATE_CHANGE
    ```
    """
    # Per the signature, the hook receives the old and new LifecycleState.
    return _hook_decorator(AIPerfHook.ON_STATE_CHANGE, func)

on_stop(func)

Decorator to specify that the function is a hook that should be called during stop. See :func:aiperf.common.hooks._hook_decorator.

Example:

class MyPlugin(AIPerfLifecycleMixin):
    @on_stop
    def _stop_plugin(self) -> None:
        pass

The above is the equivalent to setting:

MyPlugin._stop_plugin.__aiperf_hook_type__ = AIPerfHook.ON_STOP
Source code in aiperf/common/hooks.py
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
def on_stop(func: Callable) -> Callable:
    """Decorator to specify that the function is a hook that should be called during stop.
    See :func:`aiperf.common.hooks._hook_decorator`.

    Example:
    ```python
    class MyPlugin(AIPerfLifecycleMixin):
        @on_stop
        def _stop_plugin(self) -> None:
            pass
    ```

    The above is the equivalent to setting:
    ```python
    MyPlugin._stop_plugin.__aiperf_hook_type__ = AIPerfHook.ON_STOP
    ```
    """
    # Parameter-less hook: only the ON_STOP hook type is attached.
    return _hook_decorator(AIPerfHook.ON_STOP, func)

provides_hooks(*hook_types)

Decorator to specify that the class provides a hook of the given type to all of its subclasses.

Example:

@provides_hooks(AIPerfHook.ON_MESSAGE)
class MessageBusClientMixin(CommunicationMixin):
    pass

The above is the equivalent to setting:

MessageBusClientMixin.__provides_hooks__ = {AIPerfHook.ON_MESSAGE}
Source code in aiperf/common/hooks.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
def provides_hooks(
    *hook_types: HookType,
) -> Callable[[type[HooksMixinT]], type[HooksMixinT]]:
    """Decorator to specify that the class provides a hook of the given type to all of its subclasses.

    Example:
    ```python
    @provides_hooks(AIPerfHook.ON_MESSAGE)
    class MessageBusClientMixin(CommunicationMixin):
        pass
    ```

    The above is the equivalent to setting:
    ```python
    MessageBusClientMixin.__provides_hooks__ = {AIPerfHook.ON_MESSAGE}
    ```
    """

    def decorator(cls: type[HooksMixinT]) -> type[HooksMixinT]:
        # Record the provided hook types as a set on the class; subclasses see
        # it through normal attribute inheritance.
        setattr(cls, HookAttrs.PROVIDES_HOOKS, set(hook_types))
        return cls

    return decorator

aiperf.common.logging

MultiProcessLogHandler

Bases: RichHandler

Custom logging handler that forwards log records to a multiprocessing queue.

Source code in aiperf/common/logging.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
class MultiProcessLogHandler(RichHandler):
    """Logging handler that forwards log records to a multiprocessing queue
    so a parent process can render them centrally."""

    def __init__(
        self, log_queue: multiprocessing.Queue, service_id: str | None = None
    ) -> None:
        super().__init__()
        self.log_queue = log_queue
        self.service_id = service_id

    def emit(self, record: logging.LogRecord) -> None:
        """Emit a log record to the queue."""
        try:
            # Reduce the record to a plain, picklable dict before enqueueing.
            proc = multiprocessing.current_process()
            payload = {
                "name": record.name,
                "levelname": record.levelname,
                "levelno": record.levelno,
                "msg": record.getMessage(),
                "created": record.created,
                "process_name": proc.name,
                "process_id": proc.pid,
                "service_id": self.service_id,
            }
            self.log_queue.put_nowait(payload)
        except Exception:
            # Swallow everything (queue.Full included): logging from inside a
            # log handler would recurse, and dropping records beats blocking.
            pass

emit(record)

Emit a log record to the queue.

Source code in aiperf/common/logging.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
def emit(self, record: logging.LogRecord) -> None:
    """Emit a log record to the queue."""
    try:
        # Reduce the record to a plain, picklable dict before enqueueing.
        proc = multiprocessing.current_process()
        payload = {
            "name": record.name,
            "levelname": record.levelname,
            "levelno": record.levelno,
            "msg": record.getMessage(),
            "created": record.created,
            "process_name": proc.name,
            "process_id": proc.pid,
            "service_id": self.service_id,
        }
        self.log_queue.put_nowait(payload)
    except Exception:
        # Swallow everything (queue.Full included): logging from inside a log
        # handler would recurse, and dropping records beats blocking.
        pass

create_file_handler(log_folder, level)

Configure a file handler for logging.

Source code in aiperf/common/logging.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def create_file_handler(
    log_folder: Path,
    level: str | int,
) -> logging.FileHandler:
    """Configure a file handler for logging."""

    log_folder.mkdir(parents=True, exist_ok=True)
    log_file_path = log_folder / "aiperf.log"

    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    file_handler.setLevel(level)
    file_handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    return file_handler

get_global_log_queue() cached

Get the global log queue. Will create a new queue if it doesn't exist.

Source code in aiperf/common/logging.py
23
24
25
26
@lru_cache(maxsize=1)
def get_global_log_queue() -> multiprocessing.Queue:
    """Get the global log queue. Will create a new queue if it doesn't exist."""
    # lru_cache(maxsize=1) makes this a lazy singleton: the Queue is built on
    # the first call and the same instance is returned on every later call.
    # The queue is bounded (LOG_QUEUE_MAXSIZE) so producers can drop records
    # on queue.Full instead of blocking.
    return multiprocessing.Queue(maxsize=LOG_QUEUE_MAXSIZE)

setup_child_process_logging(log_queue=None, service_id=None, service_config=None, user_config=None)

Set up logging for a child process to send logs to the main process.

This should be called early in child process initialization.

Parameters:

Name Type Description Default
log_queue Queue | None

The multiprocessing queue to send logs to. If None, tries to get the global queue.

None
service_id str | None

The ID of the service to log under. If None, logs will be under the process name.

None
service_config ServiceConfig | None

The service configuration used to determine the log level.

None
user_config UserConfig | None

The user configuration used to determine the log folder.

None
Source code in aiperf/common/logging.py
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def setup_child_process_logging(
    log_queue: "multiprocessing.Queue | None" = None,
    service_id: str | None = None,
    service_config: ServiceConfig | None = None,
    user_config: UserConfig | None = None,
) -> None:
    """Set up logging for a child process to send logs to the main process.

    This should be called early in child process initialization.

    Args:
        log_queue: The multiprocessing queue to send logs to. If None, tries to get the global queue.
        service_id: The ID of the service to log under. If None, logs will be under the process name.
        service_config: The service configuration used to determine the log level.
        user_config: The user configuration used to determine the log folder.
    """
    # NOTE(review): despite the docstring, there is no fallback to a global
    # queue here — when log_queue is None, no queue handler is installed.
    # Confirm whether get_global_log_queue() should be used as the default.
    root_logger = logging.getLogger()
    level = ServiceDefaults.LOG_LEVEL.upper()
    if service_config:
        level = service_config.log_level.upper()

        if service_id:
            # If the service is in the trace or debug services, set the level to trace or debug
            if service_config.trace_services and _is_service_in_types(
                service_id, service_config.trace_services
            ):
                level = _TRACE
            elif service_config.debug_services and _is_service_in_types(
                service_id, service_config.debug_services
            ):
                level = _DEBUG

    # Set the root logger level to ensure logs are passed to handlers
    root_logger.setLevel(level)

    # Remove all existing handlers to avoid duplicate logs
    for existing_handler in root_logger.handlers[:]:
        root_logger.removeHandler(existing_handler)

    if log_queue is not None:
        # Set up handler for child process: forwards records to the parent
        # process via the queue instead of rendering them locally.
        queue_handler = MultiProcessLogHandler(log_queue, service_id)
        queue_handler.setLevel(level)
        root_logger.addHandler(queue_handler)

    if service_config:
        # Set up rich logging to the console
        rich_handler = RichHandler(
            rich_tracebacks=True,
            show_path=True,
            console=Console(),
            show_time=True,
            show_level=True,
            tracebacks_show_locals=False,
            log_time_format="%H:%M:%S.%f",
            omit_repeated_times=False,
        )
        rich_handler.setLevel(level)
        root_logger.addHandler(rich_handler)

    # Also mirror logs to a file under the artifact directory, when configured.
    if user_config and user_config.output.artifact_directory:
        file_handler = create_file_handler(
            user_config.output.artifact_directory / "logs", level
        )
        root_logger.addHandler(file_handler)

setup_rich_logging(user_config, service_config)

Set up rich logging with appropriate configuration.

Source code in aiperf/common/logging.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def setup_rich_logging(user_config: UserConfig, service_config: ServiceConfig) -> None:
    """Set up rich logging with appropriate configuration.

    Args:
        user_config: Provides the artifact directory used for file logging.
        service_config: Provides the log level applied to the root logger.
    """
    # Set logging level for the root logger (affects all loggers)
    level = service_config.log_level.upper()
    logging.root.setLevel(level)

    rich_handler = RichHandler(
        rich_tracebacks=True,
        show_path=True,
        console=Console(),
        show_time=True,
        show_level=True,
        tracebacks_show_locals=False,
        log_time_format="%H:%M:%S.%f",
        omit_repeated_times=False,
    )
    logging.root.addHandler(rich_handler)

    # Enable file logging for services
    # TODO: Use config to determine if file logging is enabled and the folder path.
    # Reuse create_file_handler instead of duplicating its logic inline: this
    # keeps the formatter consistent, uses setFormatter() rather than assigning
    # the .formatter attribute directly, and adds the UTF-8 encoding the inline
    # version omitted.
    file_handler = create_file_handler(
        user_config.output.artifact_directory / "logs", level
    )
    logging.root.addHandler(file_handler)

    _logger.debug(lambda: f"Logging initialized with level: {level}")

aiperf.common.messages.base_messages

ErrorMessage

Bases: Message

Message containing error data.

Source code in aiperf/common/messages/base_messages.py
112
113
114
115
116
117
class ErrorMessage(Message):
    """Message containing error data."""

    # Discriminator value used by Message.from_json to dispatch to this class.
    message_type: MessageTypeT = MessageType.ERROR

    # Structured details of the error being reported.
    error: ErrorDetails = Field(..., description="Error information")

Message

Bases: AIPerfBaseModel

Base message class for optimized message handling. Based on the AIPerfBaseModel class, so it supports the @exclude_if_none decorator. See `AIPerfBaseModel` for more details.

This class provides a base for all messages, including common fields like message_type, request_ns, and request_id. It also supports optional field exclusion based on the @exclude_if_none decorator.

Each message model should inherit from this class, set the message_type field, and define its own additional fields.

Example:

@exclude_if_none("some_field")
class ExampleMessage(Message):
    some_field: int | None = Field(default=None)
    other_field: int = Field(default=1)
Source code in aiperf/common/messages/base_messages.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
@exclude_if_none("request_ns", "request_id")
class Message(AIPerfBaseModel):
    """Base message class for optimized message handling. Based on the AIPerfBaseModel class,
    so it supports @exclude_if_none decorator. see :class:`AIPerfBaseModel` for more details.

    This class provides a base for all messages, including common fields like message_type,
    request_ns, and request_id. It also supports optional field exclusion based on the
    @exclude_if_none decorator.

    Each message model should inherit from this class, set the message_type field,
    and define its own additional fields.

    Example:
    ```python
    @exclude_if_none("some_field")
    class ExampleMessage(Message):
        some_field: int | None = Field(default=None)
        other_field: int = Field(default=1)
    ```
    """

    _message_type_lookup: ClassVar[dict[MessageTypeT, type["Message"]]] = {}
    """Lookup table for message types to their corresponding message classes. This is used to automatically
    deserialize messages from JSON strings to their corresponding class type."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if hasattr(cls, "message_type") and cls.message_type is not None:
            # Store concrete message classes in the lookup table
            cls._message_type_lookup[cls.message_type] = cls
            _logger.trace(f"Added {cls.message_type} to message type lookup")

    message_type: MessageTypeT = Field(
        ...,
        description="The type of the message. Must be set in the subclass.",
    )

    request_ns: int | None = Field(
        default=None,
        description="Timestamp of the request",
    )

    request_id: str | None = Field(
        default=None,
        description="ID of the request",
    )

    # TODO: Does this allow you to use model_validate_json and have it forward it to from_json? Need to test.
    @classmethod
    def __get_validators__(cls):
        yield cls.from_json

    @classmethod
    def from_json(cls, json_str: str | bytes | bytearray) -> "Message":
        """Deserialize a message from a JSON string, attempting to auto-detect the message type.
        NOTE: If you already know the message type, use the more performant :meth:`from_json_with_type` instead."""
        data = json.loads(json_str)
        message_type = data.get("message_type")
        if not message_type:
            raise ValueError(f"Missing message_type: {json_str}")

        # Use .get() so an unknown type raises the documented ValueError below.
        # A subscript lookup would raise KeyError first, making the guard
        # unreachable.
        message_class = cls._message_type_lookup.get(message_type)
        if message_class is None:
            raise ValueError(f"Unknown message type: {message_type}")

        return message_class.model_validate(data)

    @classmethod
    def from_json_with_type(
        cls, message_type: MessageTypeT, json_str: str | bytes | bytearray
    ) -> "Message":
        """Deserialize a message from a JSON string with a specific message type.
        NOTE: This is more performant than :meth:`from_json` because it does not need to
        convert the JSON string to a dictionary first."""
        # Use .get() so an unknown type raises the documented ValueError below
        # instead of an unhandled KeyError from a subscript lookup.
        message_class = cls._message_type_lookup.get(message_type)
        if message_class is None:
            raise ValueError(f"Unknown message type: {message_type}")
        return message_class.model_validate_json(json_str)

    def __str__(self) -> str:
        return self.model_dump_json()

from_json(json_str) classmethod

Deserialize a message from a JSON string, attempting to auto-detect the message type. NOTE: If you already know the message type, use the more performant :meth:from_json_with_type instead.

Source code in aiperf/common/messages/base_messages.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
@classmethod
def from_json(cls, json_str: str | bytes | bytearray) -> "Message":
    """Deserialize a message from a JSON string, attempting to auto-detect the message type.
    NOTE: If you already know the message type, use the more performant :meth:`from_json_with_type` instead."""
    data = json.loads(json_str)
    message_type = data.get("message_type")
    if not message_type:
        raise ValueError(f"Missing message_type: {json_str}")

    # Use .get() so an unknown type raises the documented ValueError below.
    # A subscript lookup would raise KeyError first, making the guard unreachable.
    message_class = cls._message_type_lookup.get(message_type)
    if message_class is None:
        raise ValueError(f"Unknown message type: {message_type}")

    return message_class.model_validate(data)

from_json_with_type(message_type, json_str) classmethod

Deserialize a message from a JSON string with a specific message type. NOTE: This is more performant than :meth:from_json because it does not need to convert the JSON string to a dictionary first.

Source code in aiperf/common/messages/base_messages.py
86
87
88
89
90
91
92
93
94
95
96
97
@classmethod
def from_json_with_type(
    cls, message_type: MessageTypeT, json_str: str | bytes | bytearray
) -> "Message":
    """Deserialize a message from a JSON string with a specific message type.
    NOTE: This is more performant than :meth:`from_json` because it does not need to
    convert the JSON string to a dictionary first."""
    # Use .get() so an unknown type raises the documented ValueError below
    # instead of an unhandled KeyError from a subscript lookup.
    message_class = cls._message_type_lookup.get(message_type)
    if message_class is None:
        raise ValueError(f"Unknown message type: {message_type}")
    return message_class.model_validate_json(json_str)

RequiresRequestNSMixin

Bases: Message

Mixin for messages that require a request_ns field.

Source code in aiperf/common/messages/base_messages.py
103
104
105
106
107
108
109
class RequiresRequestNSMixin(Message):
    """Mixin for messages that require a request_ns field."""

    # Narrows the optional request_ns inherited from Message to a required int,
    # defaulting to the current wall-clock time in nanoseconds at construction.
    request_ns: int = Field(  # type: ignore[assignment]
        default_factory=time.time_ns,
        description="Timestamp of the request in nanoseconds",
    )

aiperf.common.messages.command_messages

CommandMessage

Bases: TargetedServiceMessage

Message containing command data. This message is sent by the system controller to a service to command it to do something.

Source code in aiperf/common/messages/command_messages.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
class CommandMessage(TargetedServiceMessage):
    """Message containing command data.
    This message is sent by the system controller to a service to command it to do something.
    """

    # Maps command type -> concrete CommandMessage subclass for deserialization.
    _command_type_lookup: ClassVar[dict[CommandTypeT, type["CommandMessage"]]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if hasattr(cls, "command"):
            # Register each subclass under its declared command type.
            cls._command_type_lookup[cls.command] = cls

    message_type: MessageTypeT = MessageType.COMMAND

    command: CommandTypeT = Field(
        ...,
        description="Command to execute",
    )
    command_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique identifier for this command. If not provided, a random UUID will be generated.",
    )

    @classmethod
    def from_json(cls, json_str: str | bytes | bytearray) -> "CommandMessage":
        """Deserialize a command message from a JSON string, attempting to auto-detect the command type."""
        data = json.loads(json_str)
        command_type = data.get("command")
        if not command_type:
            raise ValueError(f"Missing command: {json_str}")

        # Use .get() so an unknown command type reaches the fallback below.
        # A subscript lookup would raise KeyError first, making the fallback
        # unreachable.
        command_class = cls._command_type_lookup.get(command_type)
        if command_class is None:
            _logger.debug(
                lambda: f"No command class found for command type: {command_type}"
            )
            # fallback to regular command class
            command_class = cls

        return command_class.model_validate(data)

from_json(json_str) classmethod

Deserialize a command message from a JSON string, attempting to auto-detect the command type.

Source code in aiperf/common/messages/command_messages.py
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
@classmethod
def from_json(cls, json_str: str | bytes | bytearray) -> "CommandMessage":
    """Deserialize a command message from a JSON string, attempting to auto-detect the command type."""
    data = json.loads(json_str)
    command_type = data.get("command")
    if not command_type:
        raise ValueError(f"Missing command: {json_str}")

    # Use .get() so an unknown command type reaches the fallback below.
    # A subscript lookup would raise KeyError first, making the fallback
    # unreachable.
    command_class = cls._command_type_lookup.get(command_type)
    if command_class is None:
        _logger.debug(
            lambda: f"No command class found for command type: {command_type}"
        )
        # fallback to regular command class
        command_class = cls

    return command_class.model_validate(data)

CommandResponse

Bases: TargetedServiceMessage

Message containing a command response.

Source code in aiperf/common/messages/command_messages.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
class CommandResponse(TargetedServiceMessage):
    """Message containing a command response."""

    # Specialized lookup for command response messages by status
    _command_status_lookup: ClassVar[
        dict[CommandResponseStatus, type["CommandResponse"]]
    ] = {}
    # Specialized lookup for command response messages by command type, for success messages
    _command_success_type_lookup: ClassVar[
        dict[CommandTypeT, type["CommandResponse"]]
    ] = {}

    def __init_subclass__(cls, **kwargs):
        # Auto-register subclasses so from_json can dispatch without a manual
        # registry. The first subclass seen for a given status becomes the
        # default deserialization target for that status.
        super().__init_subclass__(**kwargs)
        if (
            hasattr(cls, "status")
            and cls.status is not None
            and cls.status not in cls._command_status_lookup
        ):
            cls._command_status_lookup[cls.status] = cls
        elif (
            cls.__pydantic_fields__.get("status") is not None
            and cls.__pydantic_fields__.get("status").default
            == CommandResponseStatus.SUCCESS
        ):
            # Cache the specialized lookup by command type for success messages
            # NOTE(review): this reads cls.command as a plain class attribute;
            # confirm it is set on every SUCCESS-defaulting subclass.
            cls._command_success_type_lookup[cls.command] = cls

    message_type: MessageTypeT = MessageType.COMMAND_RESPONSE

    command: CommandTypeT = Field(
        ...,
        description="Command type that is being responded to",
    )
    command_id: str = Field(
        ..., description="The ID of the command that is being responded to"
    )
    status: CommandResponseStatus = Field(..., description="The status of the command")

    @classmethod
    def from_json(cls, json_str: str | bytes | bytearray) -> "CommandResponse":
        """Deserialize a command response message from a JSON string, attempting to auto-detect the command response type."""
        data = json.loads(json_str)
        status = data.get("status")
        if not status:
            raise ValueError(f"Missing command response status: {json_str}")
        command = data.get("command")
        if not command:
            raise ValueError(f"Missing command in command response: {json_str}")

        if status not in cls._command_status_lookup:
            raise ValueError(
                f"Unknown command response status: {status}. Valid statuses are: {list(cls._command_status_lookup.keys())}"
            )

        # Use cached command response type lookup by status
        command_response_class = cls._command_status_lookup[status]

        if (
            status == CommandResponseStatus.SUCCESS
            and command in cls._command_success_type_lookup
        ):
            # For success messages, use the specialized lookup by command type if it exists
            command_response_class = cls._command_success_type_lookup[command]

        return command_response_class.model_validate(data)

from_json(json_str) classmethod

Deserialize a command response message from a JSON string, attempting to auto-detect the command response type.

Source code in aiperf/common/messages/command_messages.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
@classmethod
def from_json(cls, json_str: str | bytes | bytearray) -> "CommandResponse":
    """Deserialize a command response message from a JSON string, attempting to auto-detect the command response type."""
    data = json.loads(json_str)
    status = data.get("status")
    if not status:
        raise ValueError(f"Missing command response status: {json_str}")
    command = data.get("command")
    if not command:
        raise ValueError(f"Missing command in command response: {json_str}")

    if status not in cls._command_status_lookup:
        raise ValueError(
            f"Unknown command response status: {status}. Valid statuses are: {list(cls._command_status_lookup.keys())}"
        )

    # Successful responses may have a more specific class registered per
    # command type; otherwise dispatch on status alone.
    if (
        status == CommandResponseStatus.SUCCESS
        and command in cls._command_success_type_lookup
    ):
        response_cls = cls._command_success_type_lookup[command]
    else:
        response_cls = cls._command_status_lookup[status]

    return response_cls.model_validate(data)

CommandSuccessResponse

Bases: CommandResponse

Generic command response message when a command succeeds. It should be subclassed for specific command types.

Source code in aiperf/common/messages/command_messages.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
class CommandSuccessResponse(CommandResponse):
    """Generic command response message when a command succeeds. It should be
    subclassed for specific command types."""

    # Fixed SUCCESS default; CommandResponse.__init_subclass__ uses this to
    # register subclasses in the per-command success lookup.
    status: CommandResponseStatus = CommandResponseStatus.SUCCESS
    data: Any | None = Field(
        default=None,
        description="The data of the command response",
    )

    @classmethod
    def from_command_message(
        cls, command_message: CommandMessage, service_id: str, data: Any | None = None
    ) -> Self:
        """Build a success response addressed back to the sender of *command_message*,
        echoing its command type and command_id so the caller can correlate them."""
        return cls(
            service_id=service_id,
            target_service_id=command_message.service_id,
            command=command_message.command,
            command_id=command_message.command_id,
            data=data,
        )

ConnectionProbeMessage

Bases: TargetedServiceMessage

Message containing a connection probe from a service. This is used to probe the connection to the service.

Source code in aiperf/common/messages/command_messages.py
345
346
347
348
class ConnectionProbeMessage(TargetedServiceMessage):
    """Message containing a connection probe from a service. This is used to probe the connection to the service."""

    # Discriminator value used by Message.from_json to dispatch to this class.
    message_type: MessageTypeT = MessageType.CONNECTION_PROBE

ProcessRecordsCommand

Bases: CommandMessage

Data to send with the process records command.

Source code in aiperf/common/messages/command_messages.py
284
285
286
287
288
289
290
291
292
class ProcessRecordsCommand(CommandMessage):
    """Data to send with the process records command."""

    # Command discriminator; registers this class in CommandMessage's lookup.
    command: CommandTypeT = CommandType.PROCESS_RECORDS

    cancelled: bool = Field(
        default=False,
        description="Whether the profile run was cancelled",
    )

ProcessRecordsResponse

Bases: CommandSuccessResponse

Response to the process records command.

Source code in aiperf/common/messages/command_messages.py
334
335
336
337
338
339
340
341
342
class ProcessRecordsResponse(CommandSuccessResponse):
    """Response to the process records command."""

    command: CommandTypeT = CommandType.PROCESS_RECORDS

    # Narrows the generic `data: Any | None` of CommandSuccessResponse to the
    # typed process-records result payload.
    data: ProcessRecordsResult | None = Field(  # type: ignore[assignment]
        default=None,
        description="The result of the process records command",
    )

ProfileCancelCommand

Bases: CommandMessage

Command message sent to request services to cancel profiling.

Source code in aiperf/common/messages/command_messages.py
310
311
312
313
class ProfileCancelCommand(CommandMessage):
    """Command message sent to request services to cancel profiling."""

    # Command discriminator; registers this class in CommandMessage's lookup.
    command: CommandTypeT = CommandType.PROFILE_CANCEL

ProfileConfigureCommand

Bases: CommandMessage

Data to send with the profile configure command.

Source code in aiperf/common/messages/command_messages.py
295
296
297
298
299
300
301
class ProfileConfigureCommand(CommandMessage):
    """Data to send with the profile configure command."""

    # Command discriminator; registers this class in CommandMessage's lookup.
    command: CommandTypeT = CommandType.PROFILE_CONFIGURE

    # TODO: Define this type
    config: Any = Field(..., description="Configuration for the profile")

ProfileStartCommand

Bases: CommandMessage

Command message sent to request services to start profiling.

Source code in aiperf/common/messages/command_messages.py
304
305
306
307
class ProfileStartCommand(CommandMessage):
    """Command message sent to request services to start profiling."""

    # Command discriminator; registers this class in CommandMessage's lookup.
    command: CommandTypeT = CommandType.PROFILE_START

RegisterServiceCommand

Bases: CommandMessage

Command message sent from a service to the system controller to register itself.

Source code in aiperf/common/messages/command_messages.py
322
323
324
325
326
327
328
329
330
331
class RegisterServiceCommand(CommandMessage):
    """Command message sent from a service to the system controller to register itself."""

    # Command discriminator; registers this class in CommandMessage's lookup.
    command: CommandTypeT = CommandType.REGISTER_SERVICE

    service_id: str = Field(..., description="The ID of the service to register")
    service_type: ServiceTypeT = Field(
        ..., description="The type of the service to register"
    )
    state: LifecycleState = Field(..., description="The current state of the service")

ShutdownCommand

Bases: CommandMessage

Command message sent to request a service to shutdown.

Source code in aiperf/common/messages/command_messages.py
316
317
318
319
class ShutdownCommand(CommandMessage):
    """Command message sent to request a service to shutdown."""

    # Command discriminator; registers this class in CommandMessage's lookup.
    command: CommandTypeT = CommandType.SHUTDOWN

TargetedServiceMessage

Bases: BaseServiceMessage

Message that can be targeted to a specific service by id or type. If both target_service_type and target_service_id are None, the message is sent to all services that are subscribed to the message type.

Source code in aiperf/common/messages/command_messages.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@exclude_if_none("target_service_id", "target_service_type")
class TargetedServiceMessage(BaseServiceMessage):
    """Message that can be targeted to a specific service by id or type.
    If both `target_service_type` and `target_service_id` are None, the message is
    sent to all services that are subscribed to the message type.
    """

    @model_validator(mode="after")
    def validate_target_service(self) -> Self:
        """Reject messages that set both targeting fields; at most one may be provided."""
        if self.target_service_id is not None and self.target_service_type is not None:
            raise ValueError(
                "Either target_service_id or target_service_type can be provided, but not both"
            )
        return self

    target_service_id: str | None = Field(
        default=None,
        description="ID of the target service to send the message to. "
        "If both `target_service_type` and `target_service_id` are None, the message is "
        "sent to all services that are subscribed to the message type.",
    )
    target_service_type: ServiceTypeT | None = Field(
        default=None,
        description="Type of the service to send the message to. "
        "If both `target_service_type` and `target_service_id` are None, the message is "
        "sent to all services that are subscribed to the message type.",
    )

aiperf.common.messages.credit_messages

CreditDropMessage

Bases: BaseServiceMessage

Message indicating that a credit has been dropped. This message is sent by the timing manager to workers to indicate that credit(s) have been dropped.

Source code in aiperf/common/messages/credit_messages.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class CreditDropMessage(BaseServiceMessage):
    """Message indicating that a credit has been dropped.
    This message is sent by the timing manager to workers to indicate that credit(s)
    have been dropped.
    """

    # Discriminator value used by Message.from_json to dispatch to this class.
    message_type: MessageTypeT = MessageType.CREDIT_DROP

    phase: CreditPhase = Field(..., description="The type of credit phase")
    conversation_id: str | None = Field(
        default=None, description="The ID of the conversation, if applicable."
    )
    credit_drop_ns: int | None = Field(
        default=None,
        description="Timestamp of the credit drop, if applicable. None means send ASAP.",
    )

CreditPhaseCompleteMessage

Bases: BaseServiceMessage

Message for credit phase complete. Sent by the TimingManager to report that a credit phase has completed.

Source code in aiperf/common/messages/credit_messages.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class CreditPhaseCompleteMessage(BaseServiceMessage):
    """Message for credit phase complete. Sent by the TimingManager to report that a credit phase has completed."""

    # Discriminator value used by Message.from_json to dispatch to this class.
    message_type: MessageTypeT = MessageType.CREDIT_PHASE_COMPLETE
    phase: CreditPhase = Field(..., description="The type of credit phase")
    completed: int = Field(
        ...,
        description="The number of completed credits (returned from the workers). This is the final count of completed credits.",
    )
    # ge=1 because a valid completion timestamp in nanoseconds is always positive.
    end_ns: int | None = Field(
        default=None,
        ge=1,
        description="The time in which the last credit was returned from the workers in nanoseconds. If None, the phase has not completed.",
    )

CreditPhaseProgressMessage

Bases: BaseServiceMessage

Sent by the TimingManager to report the progress of a credit phase.

Source code in aiperf/common/messages/credit_messages.py
81
82
83
84
85
86
87
88
89
90
class CreditPhaseProgressMessage(BaseServiceMessage):
    """Sent by the TimingManager to report the progress of a credit phase.

    Both counters default to 0 so a progress message can be emitted before
    any credits have been sent or returned.
    """

    message_type: MessageTypeT = MessageType.CREDIT_PHASE_PROGRESS
    phase: CreditPhase = Field(..., description="The type of credit phase")
    sent: int = Field(default=0, description="The number of sent credits")
    completed: int = Field(
        default=0,
        description="The number of completed credits (returned from the workers)",
    )

CreditPhaseSendingCompleteMessage

Bases: BaseServiceMessage

Message for credit phase sending complete. Sent by the TimingManager to report that a credit phase has completed sending.

Source code in aiperf/common/messages/credit_messages.py
 93
 94
 95
 96
 97
 98
 99
100
101
class CreditPhaseSendingCompleteMessage(BaseServiceMessage):
    """Message for credit phase sending complete. Sent by the TimingManager to report that a credit phase has completed sending."""

    message_type: MessageTypeT = MessageType.CREDIT_PHASE_SENDING_COMPLETE
    phase: CreditPhase = Field(..., description="The type of credit phase")
    # None encodes "sending still in progress"; set once the last credit is sent.
    sent_end_ns: int | None = Field(
        default=None,
        description="The time of the last sent credit in nanoseconds. If None, the phase has not sent all credits.",
    )

CreditPhaseStartMessage

Bases: BaseServiceMessage

Message for credit phase start. Sent by the TimingManager to report that a credit phase has started.

Source code in aiperf/common/messages/credit_messages.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class CreditPhaseStartMessage(BaseServiceMessage):
    """Message for credit phase start. Sent by the TimingManager to report that a credit phase has started.

    A phase is bounded either by request count (total_expected_requests) or by
    time (expected_duration_sec); the unused bound is left as None.
    """

    message_type: MessageTypeT = MessageType.CREDIT_PHASE_START
    phase: CreditPhase = Field(..., description="The type of credit phase")
    start_ns: int = Field(
        ge=1,
        description="The start time of the credit phase in nanoseconds.",
    )
    total_expected_requests: int | None = Field(
        default=None,
        ge=1,
        description="The total number of expected requests. If None, the phase is not request count based.",
    )
    expected_duration_sec: float | None = Field(
        default=None,
        ge=1,
        description="The expected duration of the credit phase in seconds. If None, the phase is not time based.",
    )

CreditReturnMessage

Bases: BaseServiceMessage

Message indicating that a credit has been returned. This message is sent by a worker to the timing manager to indicate that work has been completed.

Source code in aiperf/common/messages/credit_messages.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class CreditReturnMessage(BaseServiceMessage):
    """Message indicating that a credit has been returned.
    This message is sent by a worker to the timing manager to indicate that work has
    been completed.
    """

    message_type: MessageTypeT = MessageType.CREDIT_RETURN

    phase: CreditPhase = Field(
        ...,
        description="The Credit Phase of the credit drop. This is so the TimingManager can track the progress of the credit phase.",
    )
    delayed_ns: int | None = Field(
        default=None,
        ge=1,
        description="The number of nanoseconds the credit drop was delayed by, or None if the credit was sent on time. "
        "NOTE: This is only applicable if the original credit_drop_ns was not None.",
    )
    # TODO: Does it make more sense for this to be part of the RequestRecord?
    pre_inference_ns: int | None = Field(
        default=None,
        description="The latency of the credit in nanoseconds from when it was first received to when the inference request was sent. "
        "This can be used to trace the latency in order to identify bottlenecks or other issues.",
        ge=0,
    )

    @property
    def delayed(self) -> bool:
        """Whether the credit drop was delayed (i.e. ``delayed_ns`` is set)."""
        return self.delayed_ns is not None

CreditsCompleteMessage

Bases: BaseServiceMessage

Credits complete message sent by the TimingManager to the System controller to signify all Credit Phases have been completed.

Source code in aiperf/common/messages/credit_messages.py
120
121
122
123
124
class CreditsCompleteMessage(BaseServiceMessage):
    """Credits complete message sent by the TimingManager to the System controller to signify all Credit Phases
    have been completed."""

    # Marker message: carries no payload beyond the type and service_id.
    message_type: MessageTypeT = MessageType.CREDITS_COMPLETE

aiperf.common.messages.dataset_messages

ConversationRequestMessage

Bases: BaseServiceMessage

Message to request a full conversation by ID.

Source code in aiperf/common/messages/dataset_messages.py
12
13
14
15
16
17
18
19
20
21
22
23
class ConversationRequestMessage(BaseServiceMessage):
    """Message to request a full conversation by ID.

    A None conversation_id lets the receiver choose a conversation;
    a None credit_phase falls back to the receiver's default phase.
    """

    message_type: MessageTypeT = MessageType.CONVERSATION_REQUEST

    conversation_id: str | None = Field(
        default=None, description="The session ID of the conversation"
    )
    credit_phase: CreditPhase | None = Field(
        default=None,
        description="The type of credit phase (either warmup or profiling). If not provided, the timing manager will use the default credit phase.",
    )

ConversationResponseMessage

Bases: BaseServiceMessage

Message containing a full conversation.

Source code in aiperf/common/messages/dataset_messages.py
26
27
28
29
30
class ConversationResponseMessage(BaseServiceMessage):
    """Message containing a full conversation.

    Reply to a ConversationRequestMessage.
    """

    message_type: MessageTypeT = MessageType.CONVERSATION_RESPONSE
    conversation: Conversation = Field(..., description="The conversation data")

ConversationTurnRequestMessage

Bases: BaseServiceMessage

Message to request a single turn from a conversation.

Source code in aiperf/common/messages/dataset_messages.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class ConversationTurnRequestMessage(BaseServiceMessage):
    """Message to request a single turn from a conversation."""

    message_type: MessageTypeT = MessageType.CONVERSATION_TURN_REQUEST

    # Required here (unlike ConversationRequestMessage) because a specific
    # turn can only be addressed within a specific conversation.
    conversation_id: str = Field(
        ...,
        description="The ID of the conversation.",
    )
    turn_index: int = Field(
        ...,
        ge=0,
        description="The index of the turn in the conversation.",
    )

ConversationTurnResponseMessage

Bases: BaseServiceMessage

Message containing a single turn from a conversation.

Source code in aiperf/common/messages/dataset_messages.py
49
50
51
52
53
54
class ConversationTurnResponseMessage(BaseServiceMessage):
    """Message containing a single turn from a conversation.

    Reply to a ConversationTurnRequestMessage.
    """

    message_type: MessageTypeT = MessageType.CONVERSATION_TURN_RESPONSE

    turn: Turn = Field(..., description="The turn data")

DatasetConfiguredNotification

Bases: BaseServiceMessage

Notification sent to notify other services that the dataset has been configured.

Source code in aiperf/common/messages/dataset_messages.py
74
75
76
77
class DatasetConfiguredNotification(BaseServiceMessage):
    """Notification sent to notify other services that the dataset has been configured."""

    # Marker message: no payload beyond the type and service_id.
    message_type: MessageTypeT = MessageType.DATASET_CONFIGURED_NOTIFICATION

DatasetTimingRequest

Bases: BaseServiceMessage

Message for a dataset timing request.

Source code in aiperf/common/messages/dataset_messages.py
57
58
59
60
class DatasetTimingRequest(BaseServiceMessage):
    """Message for a dataset timing request.

    Answered by a DatasetTimingResponse carrying the timing data.
    """

    message_type: MessageTypeT = MessageType.DATASET_TIMING_REQUEST

DatasetTimingResponse

Bases: BaseServiceMessage

Message for a dataset timing response.

Source code in aiperf/common/messages/dataset_messages.py
63
64
65
66
67
68
69
70
71
class DatasetTimingResponse(BaseServiceMessage):
    """Message for a dataset timing response."""

    message_type: MessageTypeT = MessageType.DATASET_TIMING_RESPONSE

    # Each entry pairs a timestamp with the conversation it applies to.
    timing_data: list[tuple[int, str]] = Field(
        ...,
        description="The timing data of the dataset. Tuple of (timestamp, conversation_id)",
    )

aiperf.common.messages.health_messages

WorkerHealthMessage

Bases: BaseServiceMessage

Message for a worker health check.

Source code in aiperf/common/messages/health_messages.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class WorkerHealthMessage(BaseServiceMessage):
    """Message for a worker health check.

    Carries the worker's process health plus per-phase task statistics; the
    aggregate properties below sum those stats across all credit phases.
    """

    message_type: MessageTypeT = MessageType.WORKER_HEALTH

    process: ProcessHealth = Field(..., description="The health of the worker process")

    # Worker specific fields
    task_stats: dict[CreditPhase, WorkerPhaseTaskStats] = Field(
        ...,
        description="Stats for the tasks that have been sent to the worker, keyed by the credit phase",
    )

    @property
    def total_tasks(self) -> int:
        """The total number of tasks that have been sent to the worker."""
        return sum(task_stats.total for task_stats in self.task_stats.values())

    @property
    def completed_tasks(self) -> int:
        """The number of tasks that have been completed by the worker."""
        return sum(task_stats.completed for task_stats in self.task_stats.values())

    @property
    def failed_tasks(self) -> int:
        """The number of tasks that have failed by the worker."""
        return sum(task_stats.failed for task_stats in self.task_stats.values())

    @property
    def in_progress_tasks(self) -> int:
        """The number of tasks that are currently in progress by the worker."""
        return sum(task_stats.in_progress for task_stats in self.task_stats.values())

    @property
    def error_rate(self) -> float:
        """The error rate of the worker (failed_tasks / total_tasks).

        Returns 0.0 when no tasks have been sent, avoiding a ZeroDivisionError.
        """
        if self.total_tasks == 0:
            # Return a float to match the declared return type (was a bare int 0).
            return 0.0
        return self.failed_tasks / self.total_tasks

completed_tasks property

The number of tasks that have been completed by the worker.

error_rate property

The error rate of the worker.

failed_tasks property

The number of tasks that have failed by the worker.

in_progress_tasks property

The number of tasks that are currently in progress by the worker.

total_tasks property

The total number of tasks that have been sent to the worker.

aiperf.common.messages.inference_messages

InferenceResultsMessage

Bases: BaseServiceMessage

Message for inference results.

Source code in aiperf/common/messages/inference_messages.py
19
20
21
22
23
24
25
26
class InferenceResultsMessage(BaseServiceMessage):
    """Message for inference results.

    Wraps a single RequestRecord produced by a worker.
    """

    message_type: MessageTypeT = MessageType.INFERENCE_RESULTS

    # SerializeAsAny preserves subclass fields of RequestRecord during serialization.
    record: SerializeAsAny[RequestRecord] = Field(
        ..., description="The inference results record"
    )

MetricRecordsMessage

Bases: BaseServiceMessage

Message from the result parser to the records manager to notify it of the metric records for a single request.

Source code in aiperf/common/messages/inference_messages.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class MetricRecordsMessage(BaseServiceMessage):
    """Message from the result parser to the records manager to notify it
    of the metric records for a single request."""

    message_type: MessageTypeT = MessageType.METRIC_RECORDS

    worker_id: str = Field(
        ..., description="The ID of the worker that processed the request."
    )
    credit_phase: CreditPhase = Field(
        ..., description="The credit phase of the request."
    )
    results: list[dict[MetricTagT, MetricValueTypeT]] = Field(
        ..., description="The record processor results"
    )
    error: ErrorDetails | None = Field(
        default=None, description="The error details if the request failed."
    )

    @property
    def valid(self) -> bool:
        """Whether the request was valid (no error details attached)."""
        return self.error is None

valid property

Whether the request was valid.

ParsedInferenceResultsMessage

Bases: BaseServiceMessage

Message for parsed inference results.

Source code in aiperf/common/messages/inference_messages.py
29
30
31
32
33
34
35
36
37
38
39
class ParsedInferenceResultsMessage(BaseServiceMessage):
    """Message for parsed inference results."""

    message_type: MessageTypeT = MessageType.PARSED_INFERENCE_RESULTS

    worker_id: str = Field(
        ..., description="The ID of the worker that processed the request."
    )
    # SerializeAsAny preserves subclass fields of ParsedResponseRecord during serialization.
    record: SerializeAsAny[ParsedResponseRecord] = Field(
        ..., description="The post process results record"
    )

aiperf.common.messages.progress_messages

AllRecordsReceivedMessage

Bases: BaseServiceMessage, RequiresRequestNSMixin

This is sent by the RecordsManager to signal that all parsed records have been received, and the final processing stats are available.

Source code in aiperf/common/messages/progress_messages.py
 99
100
101
102
103
104
105
class AllRecordsReceivedMessage(BaseServiceMessage, RequiresRequestNSMixin):
    """This is sent by the RecordsManager to signal that all parsed records have been received, and the final processing stats are available."""

    message_type: MessageTypeT = MessageType.ALL_RECORDS_RECEIVED
    final_processing_stats: PhaseProcessingStats = Field(
        ..., description="The final processing stats for the profile run"
    )

ProcessRecordsResultMessage

Bases: BaseServiceMessage

Message for process records result.

Source code in aiperf/common/messages/progress_messages.py
108
109
110
111
112
113
class ProcessRecordsResultMessage(BaseServiceMessage):
    """Message for process records result.

    Wraps a single ProcessRecordsResult payload.
    """

    message_type: MessageTypeT = MessageType.PROCESS_RECORDS_RESULT

    results: ProcessRecordsResult = Field(..., description="The process records result")

ProcessingStatsMessage

Bases: BaseServiceMessage

Message for processing stats. Sent by the records manager to the system controller to report the stats of the profile run.

Source code in aiperf/common/messages/progress_messages.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class ProcessingStatsMessage(BaseServiceMessage):
    """Message for processing stats. Sent by the records manager to the system controller to report the stats of the profile run."""

    # NOTE(review): RecordsProcessingStatsMessage also uses MessageType.PROCESSING_STATS —
    # confirm that two message classes sharing one type is intended.
    message_type: MessageTypeT = MessageType.PROCESSING_STATS

    error_count: int = Field(default=0, description="The number of errors encountered")
    completed: int = Field(
        default=0, description="The number of requests processed by the records manager"
    )
    worker_completed: dict[str, int] = Field(
        default_factory=dict,
        description="Per-worker request completion counts, keyed by worker service_id",
    )
    worker_errors: dict[str, int] = Field(
        default_factory=dict,
        description="Per-worker error counts, keyed by worker service_id",
    )

ProfileProgressMessage

Bases: BaseServiceMessage

Message for profile progress. Sent by the timing manager to the system controller to report the progress of the profile run.

Source code in aiperf/common/messages/progress_messages.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class ProfileProgressMessage(BaseServiceMessage):
    """Message for profile progress. Sent by the timing manager to the system controller to report the progress of the profile run."""

    message_type: MessageTypeT = MessageType.PROFILE_PROGRESS

    profile_id: str | None = Field(
        default=None, description="The ID of the current profile"
    )
    start_ns: int = Field(
        ..., description="The start time of the profile run in nanoseconds"
    )
    # None end_ns means the run is still in progress.
    end_ns: int | None = Field(
        default=None, description="The end time of the profile run in nanoseconds"
    )
    total: int = Field(
        ..., description="The total number of inference requests to be made (if known)"
    )
    completed: int = Field(
        ..., description="The number of inference requests completed"
    )
    warmup: bool = Field(
        default=False,
        description="Whether this is the warmup phase of the profile run",
    )

ProfileResultsMessage

Bases: BaseServiceMessage

Message for profile results.

Source code in aiperf/common/messages/progress_messages.py
91
92
93
94
95
96
class ProfileResultsMessage(BaseServiceMessage):
    """Message for profile results.

    Wraps the final ProfileResults payload for a run.
    """

    message_type: MessageTypeT = MessageType.PROFILE_RESULTS

    profile_results: ProfileResults = Field(..., description="The profile results")

RecordsProcessingStatsMessage

Bases: BaseServiceMessage

Message for processing stats. Sent by the RecordsManager to report the stats of the profile run. This contains the stats for a single credit phase only.

Source code in aiperf/common/messages/progress_messages.py
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class RecordsProcessingStatsMessage(BaseServiceMessage):
    """Message for processing stats. Sent by the RecordsManager to report the stats of the profile run.
    This contains the stats for a single credit phase only."""

    # NOTE(review): ProcessingStatsMessage also uses MessageType.PROCESSING_STATS —
    # confirm that two message classes sharing one type is intended.
    message_type: MessageTypeT = MessageType.PROCESSING_STATS

    processing_stats: PhaseProcessingStats = Field(
        ..., description="The stats for the credit phase"
    )
    worker_stats: dict[str, PhaseProcessingStats] = Field(
        default_factory=dict,
        description="The stats for each worker how many requests were processed and how many errors were "
        "encountered, keyed by worker service_id",
    )

SweepProgressMessage

Bases: BaseServiceMessage

Message for sweep progress.

Source code in aiperf/common/messages/progress_messages.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class SweepProgressMessage(BaseServiceMessage):
    """Message for sweep progress."""

    # TODO: add profile information

    message_type: MessageTypeT = MessageType.SWEEP_PROGRESS

    sweep_id: str = Field(..., description="The ID of the current sweep")
    sweep_start_ns: int = Field(
        ..., description="The start time of the sweep in nanoseconds"
    )
    # None end_ns means the sweep is still in progress.
    end_ns: int | None = Field(
        default=None, description="The end time of the profile run in nanoseconds"
    )

aiperf.common.messages.service_messages

BaseServiceErrorMessage

Bases: BaseServiceMessage

Base message containing error data.

Source code in aiperf/common/messages/service_messages.py
92
93
94
95
96
97
class BaseServiceErrorMessage(BaseServiceMessage):
    """Base message containing error data.

    Subclasses inherit the SERVICE_ERROR message type unless they override it.
    """

    message_type: MessageTypeT = MessageType.SERVICE_ERROR

    error: ErrorDetails = Field(..., description="Error information")

BaseServiceMessage

Bases: Message

Base message that is sent from a service. Requires a service_id field to specify the service that sent the message.

Source code in aiperf/common/messages/service_messages.py
21
22
23
24
25
26
27
28
class BaseServiceMessage(Message):
    """Base message that is sent from a service. Requires a service_id field to specify
    the service that sent the message."""

    # Required on every service message so receivers can attribute the sender.
    service_id: str = Field(
        ...,
        description="ID of the service sending the message",
    )

BaseStatusMessage

Bases: BaseServiceMessage

Base message containing status data. This message is sent by a service to the system controller to report its status.

Source code in aiperf/common/messages/service_messages.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class BaseStatusMessage(BaseServiceMessage):
    """Base message containing status data.
    This message is sent by a service to the system controller to report its status.
    """

    # Override request_ns to be auto-filled if not provided.
    # BUGFIX: use default_factory instead of default=time.time_ns().
    # A plain default is evaluated ONCE at class-definition (import) time, so
    # every message created without an explicit request_ns would share the same
    # stale import-time timestamp. default_factory calls time.time_ns() at each
    # instantiation, yielding the actual creation time.
    request_ns: int | None = Field(
        default_factory=time.time_ns,
        description="Timestamp of the request",
    )
    state: LifecycleState = Field(
        ...,
        description="Current state of the service",
    )
    service_type: ServiceTypeT = Field(
        ...,
        description="Type of service",
    )

HeartbeatMessage

Bases: BaseStatusMessage

Message containing heartbeat data. This message is sent by a service to the system controller to indicate that it is still running.

Source code in aiperf/common/messages/service_messages.py
67
68
69
70
71
72
73
class HeartbeatMessage(BaseStatusMessage):
    """Message containing heartbeat data.
    This message is sent by a service to the system controller to indicate that it is
    still running.
    """

    # Inherits state/service_type/request_ns from BaseStatusMessage.
    message_type: MessageTypeT = MessageType.HEARTBEAT

NotificationMessage

Bases: BaseServiceMessage

Message containing a notification from a service. This is used to notify other services of events.

Source code in aiperf/common/messages/service_messages.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class NotificationMessage(BaseServiceMessage):
    """Message containing a notification from a service. This is used to notify other services of events."""

    message_type: MessageTypeT = MessageType.NOTIFICATION

    notification_type: NotificationType = Field(
        ...,
        description="The type of notification",
    )

    # SerializeAsAny preserves subclass fields of the attached BaseModel payload.
    data: SerializeAsAny[BaseModel | None] = Field(
        default=None,
        description="Data to send with the notification",
    )

RegistrationMessage

Bases: BaseStatusMessage

Message containing registration data. This message is sent by a service to the system controller to register itself.

Source code in aiperf/common/messages/service_messages.py
59
60
61
62
63
64
class RegistrationMessage(BaseStatusMessage):
    """Message containing registration data.
    This message is sent by a service to the system controller to register itself.
    """

    # Inherits state/service_type/request_ns from BaseStatusMessage.
    message_type: MessageTypeT = MessageType.REGISTRATION

StatusMessage

Bases: BaseStatusMessage

Message containing status data. This message is sent by a service to the system controller to report its status.

Source code in aiperf/common/messages/service_messages.py
51
52
53
54
55
56
class StatusMessage(BaseStatusMessage):
    """Message containing status data.
    This message is sent by a service to the system controller to report its status.
    """

    # Inherits state/service_type/request_ns from BaseStatusMessage.
    message_type: MessageTypeT = MessageType.STATUS

aiperf.common.mixins.aiperf_lifecycle_mixin

AIPerfLifecycleMixin

Bases: TaskManagerMixin, HooksMixin

This mixin provides a lifecycle state machine, and is the basis for most components in the AIPerf framework. It provides a set of hooks that are run at each state transition, and the ability to define background tasks that are automatically run on @on_start, and canceled via @on_stop.

It exposes to the outside world initialize, start, and stop methods, as well as getting the current state of the lifecycle. These simple methods promote a simple interface for users to interact with.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
@provides_hooks(
    AIPerfHook.ON_INIT,
    AIPerfHook.ON_START,
    AIPerfHook.ON_STOP,
    AIPerfHook.ON_STATE_CHANGE,
    AIPerfHook.BACKGROUND_TASK,
)
@implements_protocol(AIPerfLifecycleProtocol)
class AIPerfLifecycleMixin(TaskManagerMixin, HooksMixin):
    """This mixin provides a lifecycle state machine, and is the basis for most components in the AIPerf framework.
    It provides a set of hooks that are run at each state transition, and the ability to define background tasks
    that are automatically run on @on_start, and canceled via @on_stop.

    It exposes to the outside world `initialize`, `start`, and `stop` methods, as well as getting the
    current state of the lifecycle. These simple methods promote a simple interface for users to interact with.
    """

    def __init__(self, id: str | None = None, **kwargs) -> None:
        """
        Args:
            id: The id of the lifecycle. If not provided, a random uuid will be generated.
        """
        self.id = id or f"{self.__class__.__name__}_{uuid.uuid4().hex[:8]}"
        self._state = LifecycleState.CREATED
        self.initialized_event = asyncio.Event()
        self.started_event = asyncio.Event()
        self._stop_requested_event = asyncio.Event()
        self.stopped_event = asyncio.Event()  # set on stop or failure
        self._children: list[AIPerfLifecycleProtocol] = []
        # Default the logger name to this lifecycle's id so log lines are attributable.
        if "logger_name" not in kwargs:
            kwargs["logger_name"] = self.id
        super().__init__(**kwargs)

    @property
    def state(self) -> LifecycleState:
        """The current lifecycle state."""
        return self._state

    # NOTE: This was moved to not be a property setter, as we want it to be async so we can
    # run the hooks and await them. Otherwise there are issues with creating a task when the
    # lifecycle is trying to stop.
    async def _set_state(self, state: LifecycleState) -> None:
        # No-op (and no hooks fired) when the state is unchanged.
        if state == self._state:
            return
        old_state = self._state
        self._state = state
        if self.is_debug_enabled:
            self.debug(f"State changed from {old_state!r} to {state!r} for {self}")
        await self.run_hooks(
            AIPerfHook.ON_STATE_CHANGE, old_state=old_state, new_state=state
        )

    @property
    def was_initialized(self) -> bool:
        """Whether initialization has completed at least once."""
        return self.initialized_event.is_set()

    @property
    def was_started(self) -> bool:
        """Whether the lifecycle has completed starting at least once."""
        return self.started_event.is_set()

    @property
    def was_stopped(self) -> bool:
        """Whether the lifecycle has stopped (or failed — see _fail)."""
        return self.stopped_event.is_set()

    @property
    def is_running(self) -> bool:
        """Whether the lifecycle's current state is LifecycleState.RUNNING."""
        return self.state == LifecycleState.RUNNING

    @property
    def stop_requested(self) -> bool:
        """Whether the lifecycle has been requested to stop."""
        return self._stop_requested_event.is_set()

    @stop_requested.setter
    def stop_requested(self, value: bool) -> None:
        # Backed by an asyncio.Event so background tasks can await it.
        if value:
            self._stop_requested_event.set()
        else:
            self._stop_requested_event.clear()

    async def _execute_state_transition(
        self,
        transient_state: LifecycleState,
        final_state: LifecycleState,
        hook_type: AIPerfHook,
        event: asyncio.Event,
        reverse: bool = False,
    ) -> None:
        """This method wraps the functionality of changing the state of the lifecycle, and running the hooks.
        It is used to ensure that the state change and hook running are atomic, and that the state change is
        only made after the hooks have completed. It also takes in an event that is set when the state change is complete.
        This is useful for external code waiting for the state change to complete before continuing.

        If reverse is True, the hooks are run in reverse order. This is useful for stopping the lifecycle in the reverse order of starting it.
        """
        await self._set_state(transient_state)
        self.debug(lambda: f"{transient_state.title()} {self}")
        try:
            await self.run_hooks(hook_type, reverse=reverse)
            await self._set_state(final_state)
            self.debug(lambda: f"{self} is now {final_state.title()}")
            event.set()
        except Exception as e:
            # Any hook failure moves the lifecycle to FAILED and cancels (see _fail).
            await self._fail(e)

    async def initialize(self) -> None:
        """Initialize the lifecycle and run the @on_init hooks.

        NOTE: It is generally discouraged from overriding this method.
        Instead, use the @on_init hook to handle your own initialization logic.
        """
        # Idempotent: repeated calls while (or after) initializing/starting are ignored.
        if self.state in (
            LifecycleState.INITIALIZING,
            LifecycleState.INITIALIZED,
            LifecycleState.STARTING,
            LifecycleState.RUNNING,
        ):
            self.debug(
                lambda: f"Ignoring initialize request for {self} in state {self.state}"
            )
            return

        if self.state != LifecycleState.CREATED:
            raise InvalidStateError(
                f"Cannot initialize from state {self.state} for {self}"
            )

        await self._execute_state_transition(
            LifecycleState.INITIALIZING,
            LifecycleState.INITIALIZED,
            AIPerfHook.ON_INIT,
            self.initialized_event,
        )

    async def start(self) -> None:
        """Start the lifecycle and run the @on_start hooks.

        NOTE: It is generally discouraged from overriding this method.
        Instead, use the @on_start hook to handle your own starting logic.
        """
        # Idempotent: repeated calls while starting/running are ignored.
        if self.state in (
            LifecycleState.STARTING,
            LifecycleState.RUNNING,
        ):
            self.debug(
                lambda: f"Ignoring start request for {self} in state {self.state}"
            )
            return

        if self.state != LifecycleState.INITIALIZED:
            raise InvalidStateError(f"Cannot start from state {self.state} for {self}")

        await self._execute_state_transition(
            LifecycleState.STARTING,
            LifecycleState.RUNNING,
            AIPerfHook.ON_START,
            self.started_event,
        )

    async def initialize_and_start(self) -> None:
        """Initialize and start the lifecycle. This is a convenience method that calls `initialize` and `start` in sequence."""
        await self.initialize()
        await self.start()

    async def stop(self) -> None:
        """Stop the lifecycle and run the @on_stop hooks.

        NOTE: It is generally discouraged from overriding this method.
        Instead, use the @on_stop hook to handle your own stopping logic.
        """
        # NOTE(review): unlike initialize/start, this guards only on stop_requested,
        # not on the current state — confirm stopping from any state is intended.
        if self.stop_requested:
            self.debug(
                lambda: f"Ignoring stop request for {self} in state {self.state}"
            )
            return

        self.stop_requested = True
        await self._execute_state_transition(
            LifecycleState.STOPPING,
            LifecycleState.STOPPED,
            AIPerfHook.ON_STOP,
            self.stopped_event,
            reverse=True,  # run the stop hooks in reverse order
        )

    @on_start
    async def _start_background_tasks(self) -> None:
        """Start all tasks that are decorated with the @background_task decorator."""
        for hook in self.get_hooks(AIPerfHook.BACKGROUND_TASK):
            if not isinstance(hook.params, BackgroundTaskParams):
                raise AttributeError(
                    f"Invalid hook parameters for {hook}: {hook.params}. Expected BackgroundTaskParams."
                )
            # Tasks observe _stop_requested_event so they wind down on stop().
            self.start_background_task(
                hook.func,
                interval=hook.params.interval,
                immediate=hook.params.immediate,
                stop_on_error=hook.params.stop_on_error,
                stop_event=self._stop_requested_event,
            )

    @on_stop
    async def _stop_all_tasks(self) -> None:
        """Stop all tasks that are decorated with the @background_task decorator,
        and any custom ones that were run using `self.execute_async()`.
        """
        await self.cancel_all_tasks()

    async def _fail(self, e: Exception) -> None:
        """Set the state to FAILED and raise an asyncio.CancelledError.
        This is used when the transition from one state to another fails.
        """
        await self._set_state(LifecycleState.FAILED)
        self.exception(f"Failed for {self}: {e}")
        self.stop_requested = True
        # stopped_event is set on failure too, so waiters are always released.
        self.stopped_event.set()
        raise asyncio.CancelledError(f"Failed for {self}: {e}") from e

    def attach_child_lifecycle(self, child: AIPerfLifecycleProtocol) -> None:
        """Attach a child lifecycle to manage. This child will now have its lifecycle managed and
        controlled by this lifecycle. Common use cases are having a Service be a parent lifecycle,
        and having supporting components such as streaming post processors, progress reporters, etc. be children.

        Children will be called in the order they were attached for initialize and start,
        and in reverse order for stop.
        """
        if self.state != LifecycleState.CREATED:
            raise InvalidStateError(
                f"Cannot attach child {child} to {self} in state {self.state}. "
                "Please attach children before initializing or starting the lifecycle."
            )
        self._children.append(child)

    @on_init
    async def _initialize_children(self) -> None:
        """Initialize all children. This is done via the @on_init hook to ensure that the children
        initialize along with the parent hooks, and not after the parent hooks, which would cause
        a race condition.
        """
        for child in self._children:
            await child.initialize()

    @on_start
    async def _start_children(self) -> None:
        """Start all children. This is done via the @on_start hook to ensure that the children
        start along with the parent hooks, and not after the parent hooks, which would cause
        a race condition.
        """
        for child in self._children:
            await child.start()

    @on_stop
    async def _stop_children(self) -> None:
        """Stop all children. This is done via the @on_stop hook to ensure that the children
        are stopped along with the parent hooks, and not after the parent hooks, which would cause
        a race condition.
        """
        # Reverse order mirrors the stop hooks running in reverse of start.
        for child in reversed(self._children):
            await child.stop()

    def __str__(self) -> str:
        return f"{self.__class__.__name__} (id={self.id})"

    def __repr__(self) -> str:
        return f"<{self.__class__.__qualname__} {self.id} (state={self.state})>"

is_running property

Whether the lifecycle's current state is LifecycleState.RUNNING.

stop_requested property writable

Whether the lifecycle has been requested to stop.

__init__(id=None, **kwargs)

Parameters:

Name Type Description Default
id str | None

The id of the lifecycle. If not provided, a random uuid will be generated.

None
Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def __init__(self, id: str | None = None, **kwargs) -> None:
    """
    Args:
        id: The id of the lifecycle. When omitted, one is generated from the
            class name plus a short random uuid suffix.
    """
    # Default id has the form "<ClassName>_<8 hex chars>".
    self.id = id if id else f"{self.__class__.__name__}_{uuid.uuid4().hex[:8]}"
    self._state = LifecycleState.CREATED
    # Events observers can await to follow lifecycle progress.
    self.initialized_event = asyncio.Event()
    self.started_event = asyncio.Event()
    self._stop_requested_event = asyncio.Event()
    self.stopped_event = asyncio.Event()  # set on stop or failure
    self._children: list[AIPerfLifecycleProtocol] = []
    # Default the logger name to the lifecycle id unless the caller chose one.
    kwargs.setdefault("logger_name", self.id)
    super().__init__(**kwargs)

attach_child_lifecycle(child)

Attach a child lifecycle to manage. This child will now have its lifecycle managed and controlled by this lifecycle. Common use cases are having a Service be a parent lifecycle, and having supporting components such as streaming post processors, progress reporters, etc. be children.

Children will be called in the order they were attached for initialize and start, and in reverse order for stop.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
241
242
243
244
245
246
247
248
249
250
251
252
253
254
def attach_child_lifecycle(self, child: AIPerfLifecycleProtocol) -> None:
    """Register *child* so its lifecycle is driven by this lifecycle.

    Typical usage: a Service acts as the parent, and supporting components
    (streaming post processors, progress reporters, etc.) are its children.

    Children are initialized and started in attachment order, and stopped in
    reverse order. Attaching is only legal before the lifecycle leaves the
    CREATED state.
    """
    if self.state == LifecycleState.CREATED:
        self._children.append(child)
        return
    raise InvalidStateError(
        f"Cannot attach child {child} to {self} in state {self.state}. "
        "Please attach children before initializing or starting the lifecycle."
    )

initialize() async

Initialize the lifecycle and run the @on_init hooks.

NOTE: It is generally discouraged from overriding this method. Instead, use the @on_init hook to handle your own initialization logic.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
async def initialize(self) -> None:
    """Initialize the lifecycle and run the @on_init hooks.

    NOTE: It is generally discouraged from overriding this method.
    Instead, use the @on_init hook to handle your own initialization logic.
    """
    # States in which an initialize request is redundant and silently ignored.
    already_under_way = (
        LifecycleState.INITIALIZING,
        LifecycleState.INITIALIZED,
        LifecycleState.STARTING,
        LifecycleState.RUNNING,
    )
    if self.state in already_under_way:
        self.debug(
            lambda: f"Ignoring initialize request for {self} in state {self.state}"
        )
        return

    # Anything other than CREATED at this point (e.g. STOPPED/FAILED) is illegal.
    if self.state != LifecycleState.CREATED:
        raise InvalidStateError(
            f"Cannot initialize from state {self.state} for {self}"
        )

    await self._execute_state_transition(
        LifecycleState.INITIALIZING,
        LifecycleState.INITIALIZED,
        AIPerfHook.ON_INIT,
        self.initialized_event,
    )

initialize_and_start() async

Initialize and start the lifecycle. This is a convenience method that calls initialize and start in sequence.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
182
183
184
185
async def initialize_and_start(self) -> None:
    """Convenience wrapper that runs `initialize` followed by `start`."""
    for phase in (self.initialize, self.start):
        await phase()

start() async

Start the lifecycle and run the @on_start hooks.

NOTE: It is generally discouraged from overriding this method. Instead, use the @on_start hook to handle your own starting logic.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
async def start(self) -> None:
    """Start the lifecycle and run the @on_start hooks.

    NOTE: It is generally discouraged from overriding this method.
    Instead, use the @on_start hook to handle your own starting logic.
    """
    # A start request while already starting/running is redundant; ignore it.
    already_under_way = (
        LifecycleState.STARTING,
        LifecycleState.RUNNING,
    )
    if self.state in already_under_way:
        self.debug(
            lambda: f"Ignoring start request for {self} in state {self.state}"
        )
        return

    # Starting is only legal from the INITIALIZED state.
    if self.state != LifecycleState.INITIALIZED:
        raise InvalidStateError(f"Cannot start from state {self.state} for {self}")

    await self._execute_state_transition(
        LifecycleState.STARTING,
        LifecycleState.RUNNING,
        AIPerfHook.ON_START,
        self.started_event,
    )

stop() async

Stop the lifecycle and run the @on_stop hooks.

NOTE: It is generally discouraged from overriding this method. Instead, use the @on_stop hook to handle your own stopping logic.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
async def stop(self) -> None:
    """Stop the lifecycle and run the @on_stop hooks.

    NOTE: It is generally discouraged from overriding this method.
    Instead, use the @on_stop hook to handle your own stopping logic.
    """
    # Only the first stop request does any work; later ones are ignored.
    if not self.stop_requested:
        self.stop_requested = True
        await self._execute_state_transition(
            LifecycleState.STOPPING,
            LifecycleState.STOPPED,
            AIPerfHook.ON_STOP,
            self.stopped_event,
            reverse=True,  # run the stop hooks in reverse order
        )
        return
    self.debug(
        lambda: f"Ignoring stop request for {self} in state {self.state}"
    )

aiperf.common.mixins.aiperf_logger_mixin

AIPerfLoggerMixin

Bases: BaseMixin

Mixin to provide lazy evaluated logging for f-strings.

This mixin provides a logger with lazy evaluation support for f-strings, and direct log functions for all standard and custom logging levels.

see :class:AIPerfLogger for more details.

Usage

class MyClass(AIPerfLoggerMixin): def __init__(self): super().__init__() self.trace(lambda: f"Processing {item} of {count} ({item / count * 100}% complete)") self.info("Simple string message") self.debug(lambda i=i: f"Binding loop variable: {i}") self.warning("Warning message: %s", "legacy support") self.success("Benchmark completed successfully") self.notice("Warmup has completed") self.exception(f"Direct f-string usage: {e}")

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
@implements_protocol(AIPerfLoggerProtocol)
class AIPerfLoggerMixin(BaseMixin):
    """Mixin adding lazily-evaluated logging helpers to a class.

    Messages may be plain strings or zero-argument callables (typically lambdas
    wrapping f-strings). A callable message is only invoked when the target
    level is actually enabled, so expensive formatting is skipped otherwise.
    Direct log methods are provided for every standard and custom level.

    see :class:`AIPerfLogger` for more details.

    Usage:
        class MyClass(AIPerfLoggerMixin):
            def __init__(self):
                super().__init__()
                self.trace(lambda: f"Processing {item} of {count} ({item / count * 100}% complete)")
                self.info("Simple string message")
                self.debug(lambda i=i: f"Binding loop variable: {i}")
                self.warning("Warning message: %s", "legacy support")
                self.success("Benchmark completed successfully")
                self.notice("Warmup has completed")
                self.exception(f"Direct f-string usage: {e}")
    """

    def __init__(self, logger_name: str | None = None, **kwargs) -> None:
        name = logger_name or self.__class__.__name__
        self.logger = AIPerfLogger(name)
        # Bind frequently-used logger internals directly onto the instance to
        # avoid the overhead of extra attribute lookups / call frames on hot
        # logging paths.
        self._log = self.logger._log
        self.is_enabled_for = self.logger._logger.isEnabledFor
        self.trace_or_debug = self.logger.trace_or_debug
        super().__init__(**kwargs)

    @property
    def is_debug_enabled(self) -> bool:
        """True when DEBUG-level records would currently be emitted."""
        return self.is_enabled_for(_DEBUG)

    @property
    def is_trace_enabled(self) -> bool:
        """True when TRACE-level records would currently be emitted."""
        return self.is_enabled_for(_TRACE)

    def log(
        self, level: int, message: str | Callable[..., str], *args, **kwargs
    ) -> None:
        """Emit a record at *level*, lazily evaluating callable messages."""
        if not self.is_enabled_for(level):
            return
        self._log(level, message, *args, **kwargs)

    def trace(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Emit a TRACE-level record with lazy evaluation."""
        if not self.is_enabled_for(_TRACE):
            return
        self._log(_TRACE, message, *args, **kwargs)

    def debug(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Emit a DEBUG-level record with lazy evaluation."""
        if not self.is_enabled_for(_DEBUG):
            return
        self._log(_DEBUG, message, *args, **kwargs)

    def info(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Emit an INFO-level record with lazy evaluation."""
        if not self.is_enabled_for(_INFO):
            return
        self._log(_INFO, message, *args, **kwargs)

    def notice(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Emit a NOTICE-level record with lazy evaluation."""
        if not self.is_enabled_for(_NOTICE):
            return
        self._log(_NOTICE, message, *args, **kwargs)

    def warning(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Emit a WARNING-level record with lazy evaluation."""
        if not self.is_enabled_for(_WARNING):
            return
        self._log(_WARNING, message, *args, **kwargs)

    def success(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Emit a SUCCESS-level record with lazy evaluation."""
        if not self.is_enabled_for(_SUCCESS):
            return
        self._log(_SUCCESS, message, *args, **kwargs)

    def error(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Emit an ERROR-level record with lazy evaluation."""
        if not self.is_enabled_for(_ERROR):
            return
        self._log(_ERROR, message, *args, **kwargs)

    def exception(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Emit an ERROR-level record with exception info and lazy evaluation."""
        if not self.is_enabled_for(_ERROR):
            return
        self._log(_ERROR, message, *args, exc_info=True, **kwargs)

    def critical(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Emit a CRITICAL-level record with lazy evaluation."""
        if not self.is_enabled_for(_CRITICAL):
            return
        self._log(_CRITICAL, message, *args, **kwargs)

critical(message, *args, **kwargs)

Log a critical message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
109
110
111
112
def critical(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a CRITICAL-level record, deferring message construction until needed."""
    if not self.is_enabled_for(_CRITICAL):
        return
    self._log(_CRITICAL, message, *args, **kwargs)

debug(message, *args, **kwargs)

Log a debug message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
74
75
76
77
def debug(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a DEBUG-level record, deferring message construction until needed."""
    if not self.is_enabled_for(_DEBUG):
        return
    self._log(_DEBUG, message, *args, **kwargs)

error(message, *args, **kwargs)

Log an error message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
 99
100
101
102
def error(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit an ERROR-level record, deferring message construction until needed."""
    if not self.is_enabled_for(_ERROR):
        return
    self._log(_ERROR, message, *args, **kwargs)

exception(message, *args, **kwargs)

Log an exception message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
104
105
106
107
def exception(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit an ERROR-level record with exception info, deferring message construction."""
    if not self.is_enabled_for(_ERROR):
        return
    self._log(_ERROR, message, *args, exc_info=True, **kwargs)

info(message, *args, **kwargs)

Log an info message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
79
80
81
82
def info(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit an INFO-level record, deferring message construction until needed."""
    if not self.is_enabled_for(_INFO):
        return
    self._log(_INFO, message, *args, **kwargs)

log(level, message, *args, **kwargs)

Log a message at a specified level with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
62
63
64
65
66
67
def log(
    self, level: int, message: str | Callable[..., str], *args, **kwargs
) -> None:
    """Emit a record at *level*, deferring message construction until needed."""
    if not self.is_enabled_for(level):
        return
    self._log(level, message, *args, **kwargs)

notice(message, *args, **kwargs)

Log a notice message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
84
85
86
87
def notice(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a NOTICE-level record, deferring message construction until needed."""
    if not self.is_enabled_for(_NOTICE):
        return
    self._log(_NOTICE, message, *args, **kwargs)

success(message, *args, **kwargs)

Log a success message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
94
95
96
97
def success(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a SUCCESS-level record, deferring message construction until needed."""
    if not self.is_enabled_for(_SUCCESS):
        return
    self._log(_SUCCESS, message, *args, **kwargs)

trace(message, *args, **kwargs)

Log a trace message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
69
70
71
72
def trace(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a TRACE-level record, deferring message construction until needed."""
    if not self.is_enabled_for(_TRACE):
        return
    self._log(_TRACE, message, *args, **kwargs)

warning(message, *args, **kwargs)

Log a warning message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
89
90
91
92
def warning(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Emit a WARNING-level record, deferring message construction until needed."""
    if not self.is_enabled_for(_WARNING):
        return
    self._log(_WARNING, message, *args, **kwargs)

aiperf.common.mixins.base_mixin

BaseMixin

Base mixin class.

This Mixin creates a contract that Mixins should always pass **kwargs to super().__init__, regardless of whether they extend another mixin or not.

This will ensure that the BaseMixin is the last mixin to have its __init__ method called, which means that all other mixins will have a proper chain of __init__ methods with the correct arguments and no accidental broken inheritance.

Source code in aiperf/common/mixins/base_mixin.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
class BaseMixin:
    """Root of the cooperative mixin ``__init__`` chain.

    Every mixin in this codebase forwards ``**kwargs`` to ``super().__init__``,
    whether or not it inherits from another mixin. Placing this class last in
    the MRO guarantees the chain terminates cleanly here, so each mixin's
    ``__init__`` runs with the arguments intended for it and no inheritance
    link is accidentally broken.
    """

    def __init__(self, **kwargs):
        # Intentionally drop any remaining kwargs here:
        # object.__init__ does not take any arguments.
        super().__init__()

aiperf.common.mixins.command_handler_mixin

CommandHandlerMixin

Bases: MessageBusClientMixin, ABC

Mixin to provide command handling functionality to a service.

This mixin is used by the BaseService class, and is not intended to be used directly.

Source code in aiperf/common/mixins/command_handler_mixin.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
@provides_hooks(AIPerfHook.ON_COMMAND)
class CommandHandlerMixin(MessageBusClientMixin, ABC):
    """Mixin to provide command handling functionality to a service.

    Routes incoming command messages to @on_command hooks, publishes the
    appropriate response (acknowledged / success / error / unhandled), and
    provides helpers to send a command and await one or many responses.

    This mixin is used by the BaseService class, and is not intended to be used directly.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str,
        **kwargs,
    ) -> None:
        """
        Args:
            service_config: Service-level configuration, forwarded to the base mixins.
            user_config: User-level configuration, forwarded to the base mixins.
            service_id: Unique id of this service; used as the reply-to address
                for command responses and to detect our own broadcasts.
        """
        self.service_config = service_config
        self.user_config = user_config
        self.service_id = service_id

        # Keep track of command IDs that have been processed.
        # This is used to avoid processing duplicate command messages.
        # NOTE(review): this set grows unbounded over the service's lifetime —
        # presumably acceptable for typical run lengths; confirm.
        self._processed_command_ids: set[str] = set()

        # Keep track of futures for single response commands.
        # This is used to wait for the response from a single service.
        self._single_response_futures: dict[str, asyncio.Future[CommandResponse]] = {}

        # Keep track of futures for multi response commands.
        # This is used to wait for the responses from multiple services.
        # Keyed by command_id, then by the responding service's id.
        self._multi_response_futures: dict[
            str, dict[str, asyncio.Future[CommandResponse]]
        ] = {}

        super().__init__(
            service_config=self.service_config,
            user_config=self.user_config,
            **kwargs,
        )

    @on_message(
        lambda self: {
            # Subscribe to all broadcast command messages.
            MessageType.COMMAND,
            # Subscribe to all command messages for this specific service type.
            f"{MessageType.COMMAND}.{self.service_type}",
            # Subscribe to all command messages for this specific service ID.
            f"{MessageType.COMMAND}.{self.service_id}",
        }
    )
    async def _process_command_message(self, message: CommandMessage) -> None:
        """
        Process a command message received from the controller or another service, and forward it to the appropriate handler.
        Wait for the handler to complete and publish the response, or handle the error and publish the failure response.
        """
        self.debug(lambda: f"Received command message: {message}")
        if message.command_id in self._processed_command_ids:
            self.debug(
                lambda: f"Received duplicate command message: {message}. Ignoring."
            )
            # If we receive a duplicate command message, we just send an acknowledged response.
            await self._publish_command_acknowledged_response(message)
            return

        self._processed_command_ids.add(message.command_id)

        if message.service_id == self.service_id:
            # In the case of a broadcast command, you will receive a command message from yourself.
            # We ignore these messages.
            self.debug(
                lambda: f"Received broadcast command message from self: {message}. Ignoring."
            )
            return

        # Go through the hooks and find the first one that matches the command type.
        # Currently, we only support a single handler per command type, so we break out of the loop after the first one.
        # The reason for this is because we are sending the result of the handler function back to the original service that sent the command.
        # If there were multiple handlers, we would need to handle multiple responses, partial errors, etc.
        # TODO: Do we want/need to add support for multiple handlers per command type?
        for hook in self.get_hooks(AIPerfHook.ON_COMMAND):
            # hook.params holds the command types this hook was registered for.
            if isinstance(hook.params, Iterable) and message.command in hook.params:
                await self._execute_command_hook(message, hook)
                # Only one handler per command type, so return after the first handler.
                return

        # If we reach here, no handler was found for the command, so we publish an unhandled response.
        await self._publish_command_unhandled_response(message)

    async def _execute_command_hook(self, message: CommandMessage, hook: Hook) -> None:
        """Execute a command hook.
        This is the internal function that wraps calling a hook function and publishing the response
        based on the return value of the hook function: None -> acknowledged,
        any other value -> success with that value, exception -> error response.
        """
        try:
            result = await hook.func(message)
            if result is None:
                # If there is no data to send back, just send an acknowledged response.
                await self._publish_command_acknowledged_response(message)
                return
            await self._publish_command_success_response(message, result)
        except Exception as e:
            self.exception(
                f"Failed to handle command {message.command} with hook {hook}: {e}"
            )
            await self._publish_command_error_response(
                message, ErrorDetails.from_exception(e)
            )

    async def _publish_command_acknowledged_response(
        self, message: CommandMessage
    ) -> None:
        """Publish a command acknowledged response to a command message."""
        await self.publish(
            CommandAcknowledgedResponse.from_command_message(message, self.service_id)
        )

    async def _publish_command_success_response(
        self, message: CommandMessage, result: Any
    ) -> None:
        """Publish a command success response to a command message."""
        await self.publish(
            CommandSuccessResponse.from_command_message(
                message, self.service_id, result
            )
        )

    async def _publish_command_error_response(
        self, message: CommandMessage, error: ErrorDetails
    ) -> None:
        """Publish a command error response to a command message."""
        await self.publish(
            CommandErrorResponse.from_command_message(message, self.service_id, error)
        )

    async def _publish_command_unhandled_response(
        self, message: CommandMessage
    ) -> None:
        """Publish a command unhandled response to a command message."""
        await self.publish(
            CommandUnhandledResponse.from_command_message(message, self.service_id)
        )

    async def send_command_and_wait_for_response(
        self, message: CommandMessage, timeout: float = DEFAULT_COMMAND_RESPONSE_TIMEOUT
    ) -> CommandResponse | ErrorDetails:
        """Send a single command message to a single service and wait for the response.
        This is useful for communicating directly with a single service.

        Returns the CommandResponse, or ErrorDetails on timeout.
        """
        # Create a future that we can asynchronously wait for the response.
        future = asyncio.Future[CommandResponse]()
        self._single_response_futures[message.command_id] = future
        await self.publish(message)
        try:
            # Wait for the response future to be set by the command response message handler.
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError as e:
            return ErrorDetails.from_exception(e)
        finally:
            # Always cancel and untrack the future so late responses are ignored.
            future.cancel()
            del self._single_response_futures[message.command_id]

    async def send_command_and_wait_for_all_responses(
        self,
        command: CommandMessage,
        service_ids: list[str],
        timeout: float = DEFAULT_COMMAND_RESPONSE_TIMEOUT,
    ) -> list[CommandResponse | ErrorDetails]:
        """Broadcast a command message to multiple services and wait for the responses from all of the specified services.
        This is useful for the system controller to send one command but wait for all services to respond.

        Returns one CommandResponse per service, or a list of ErrorDetails if
        the overall wait times out.
        """
        # Create a future to track the response for each service ID.
        self._multi_response_futures[command.command_id] = {
            service_id: asyncio.Future[CommandResponse]() for service_id in service_ids
        }
        # Send the command to all services based on the target service ID and target service type.
        await self.publish(command)
        try:
            # Wait for all the responses to come in.
            return await asyncio.wait_for(
                asyncio.gather(
                    *[
                        future
                        for future in self._multi_response_futures[
                            command.command_id
                        ].values()
                    ]
                ),
                timeout=timeout,
            )
        except asyncio.TimeoutError as e:
            return [ErrorDetails.from_exception(e) for _ in range(len(service_ids))]
        finally:
            # Clean up the response futures.
            for future in self._multi_response_futures[command.command_id].values():
                future.cancel()
            del self._multi_response_futures[command.command_id]

    @on_message(
        lambda self: {
            # NOTE: Command responses are only ever sent to the original service that sent the command,
            #       so we only need to subscribe to the service ID specific topic.
            f"{MessageType.COMMAND_RESPONSE}.{self.service_id}",
        }
    )
    async def _process_command_response_message(self, message: CommandResponse) -> None:
        """
        Process a command response message received from a service. This function is called whenever
        a command response is received, and we use it to set the result of the future for the command ID.
        This will alert the task that is waiting for the response to continue.
        """
        self.debug(lambda: f"Received command response message: {message}")

        # If the command ID is in the single response futures, we set the result of the future.
        # This will alert the task that is waiting for the response to continue.
        if message.command_id in self._single_response_futures:
            self._single_response_futures[message.command_id].set_result(message)
            return

        # If the command ID is in the multi response futures, we set the result of the future for the service ID of the sender.
        # This will alert the task that is waiting for the response to continue.
        if message.command_id in self._multi_response_futures:
            if message.service_id in self._multi_response_futures[message.command_id]:
                self._multi_response_futures[message.command_id][
                    message.service_id
                ].set_result(message)
            else:
                self.warning(
                    f"Received command response for service we were not expecting: {message.service_id}. Ignoring."
                )
            return

        # If we reach here, we received a command response that we were not tracking. It is
        # safe to ignore.
        self.debug(
            lambda: f"Received command response for untracked command: {message}. Ignoring."
        )

send_command_and_wait_for_all_responses(command, service_ids, timeout=DEFAULT_COMMAND_RESPONSE_TIMEOUT) async

Broadcast a command message to multiple services and wait for the responses from all of the specified services. This is useful for the system controller to send one command but wait for all services to respond.

Source code in aiperf/common/mixins/command_handler_mixin.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
async def send_command_and_wait_for_all_responses(
    self,
    command: CommandMessage,
    service_ids: list[str],
    timeout: float = DEFAULT_COMMAND_RESPONSE_TIMEOUT,
) -> list[CommandResponse | ErrorDetails]:
    """Broadcast *command* and await one response per id in *service_ids*.

    Useful for the system controller: send one command, then block until every
    targeted service has replied. On timeout, returns a list of ErrorDetails
    (one per expected service) instead of the responses.
    """
    # One pending future per expected responder, keyed by service id.
    pending = {
        service_id: asyncio.Future[CommandResponse]() for service_id in service_ids
    }
    self._multi_response_futures[command.command_id] = pending
    # The message bus routes the command by its target service id/type.
    await self.publish(command)
    try:
        gathered = asyncio.gather(*pending.values())
        return await asyncio.wait_for(gathered, timeout=timeout)
    except asyncio.TimeoutError as e:
        return [ErrorDetails.from_exception(e) for _ in service_ids]
    finally:
        # Cancel any still-pending futures and stop tracking this command.
        for pending_future in self._multi_response_futures[command.command_id].values():
            pending_future.cancel()
        del self._multi_response_futures[command.command_id]

send_command_and_wait_for_response(message, timeout=DEFAULT_COMMAND_RESPONSE_TIMEOUT) async

Send a single command message to a single service and wait for the response. This is useful for communicating directly with a single service.

Source code in aiperf/common/mixins/command_handler_mixin.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
async def send_command_and_wait_for_response(
    self, message: CommandMessage, timeout: float = DEFAULT_COMMAND_RESPONSE_TIMEOUT
) -> CommandResponse | ErrorDetails:
    """Send one command to one service and await its response.

    Returns the CommandResponse, or ErrorDetails if no response arrives
    within *timeout* seconds.
    """
    # The response handler resolves this future when the reply arrives.
    response_future = asyncio.Future[CommandResponse]()
    self._single_response_futures[message.command_id] = response_future
    await self.publish(message)
    try:
        return await asyncio.wait_for(response_future, timeout=timeout)
    except asyncio.TimeoutError as timeout_err:
        return ErrorDetails.from_exception(timeout_err)
    finally:
        # Cancel and untrack so a late response is simply ignored.
        response_future.cancel()
        self._single_response_futures.pop(message.command_id)

aiperf.common.mixins.communication_mixin

CommunicationMixin

Bases: AIPerfLifecycleMixin, ABC

Mixin to provide access to a CommunicationProtocol instance. This mixin should be inherited by any mixin that needs access to the communication layer to create Communication clients.

Source code in aiperf/common/mixins/communication_mixin.py
12
13
14
15
16
17
18
19
20
21
22
23
24
class CommunicationMixin(AIPerfLifecycleMixin, ABC):
    """Mixin to provide access to a CommunicationProtocol instance. This mixin should be inherited
    by any mixin that needs access to the communication layer to create Communication clients.
    """

    def __init__(self, service_config: ServiceConfig, **kwargs) -> None:
        super().__init__(service_config=service_config, **kwargs)
        self.service_config = service_config
        self.comms: CommunicationProtocol = CommunicationFactory.get_or_create_instance(
            self.service_config.comm_backend,
            config=self.service_config.comm_config,
        )
        self.attach_child_lifecycle(self.comms)

aiperf.common.mixins.hooks_mixin

HooksMixin

Bases: AIPerfLoggerMixin

Mixin for a class to be able to provide hooks to its subclasses, and to be able to run them. A "hook" is a function that is decorated with a hook type (AIPerfHook) and, optionally, parameters.

In order to provide hooks, a class MUST use the @provides_hooks decorator to declare the hook types it provides. Only list hook types that you call get_hooks or run_hooks on, to get or run the functions that are decorated with those hook types.

Provided hooks are recursively inherited by subclasses, so if a base class provides a hook, all subclasses will also provide that hook (without having to explicitly declare it, or call get_hooks or run_hooks). In fact, you typically should not get or run hooks from the base class, as this may lead to calling hooks twice.

Hooks are registered in the order they are defined within the same class from top to bottom, and each class's hooks are inspected starting with hooks defined in the lowest level of base classes, moving up to the highest subclass.

IMPORTANT: - Hook decorated methods from one class can be named the same as methods in their base classes, and BOTH will be registered. Meaning if class A and class B both have a method named _initialize, which is decorated with @on_init, and class B inherits from class A, then both _initialize methods will be registered as hooks, and will be run in the order A._initialize, then B._initialize. This is done without requiring the user to call super()._initialize in the subclass, as the base class hook will be run automatically. However, the caveat is that it is not possible to disable the hook from the base class without extra work, and if the user does accidentally call super()._initialize in the subclass, the base class hook may be called twice.

Source code in aiperf/common/mixins/hooks_mixin.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
@implements_protocol(HooksProtocol)
class HooksMixin(AIPerfLoggerMixin):
    """Mixin for a class to be able to provide hooks to its subclasses, and to be able to run them. A "hook" is a function
    that is decorated with a hook type (AIPerfHook), and optional parameters.

    In order to provide hooks, a class MUST use the `@provides_hooks` decorator to declare the hook types it provides.
    Only list hook types that you call `get_hooks` or `run_hooks` on, to get or run the functions that are decorated
    with those hook types.

    Provided hooks are recursively inherited by subclasses, so if a base class provides a hook,
    all subclasses will also provide that hook (without having to explicitly declare it, or call `get_hooks` or `run_hooks`).
    In fact, you typically should not get or run hooks from the base class, as this may lead to calling hooks twice.

    Hooks are registered in the order they are defined within the same class from top to bottom, and each class's hooks
    are inspected starting with hooks defined in the lowest level of base classes, moving up to the highest subclass.

    IMPORTANT:
    - Hook decorated methods from one class can be named the same as methods in their base classes, and BOTH will be registered.
    Meaning if class A and class B both have a method named `_initialize`, which is decorated with `@on_init`, and class B inherits from class A,
    then both `_initialize` methods will be registered as hooks, and will be run in the order A._initialize, then B._initialize.
    This is done without requiring the user to call `super()._initialize` in the subclass, as the base class hook will be run automatically.
    However, the caveat is that it is not possible to disable the hook from the base class without extra work, and if the user does accidentally
    call `super()._initialize` in the subclass, the base class hook may be called twice.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self._provided_hook_types: set[HookType] = set()

        self._hooks: dict[HookType, list[Hook]] = {}
        # Go through the MRO in reverse order to ensure that the hooks are
        # registered in the correct order (base classes first, then subclasses).
        for cls in reversed(self.__class__.__mro__):
            if hasattr(cls, HookAttrs.PROVIDES_HOOKS):
                # As we find base classes that provide hooks, we add them to the
                # set of provided hook types, which is used for validation.
                self._provided_hook_types.update(getattr(cls, HookAttrs.PROVIDES_HOOKS))

            # Go through the class's methods to find the hooks.
            for method in cls.__dict__.values():
                if not callable(method):
                    continue

                # If the method has the AIPERF_HOOK_TYPE attribute, it is a hook.
                if hasattr(method, HookAttrs.HOOK_TYPE):
                    method_hook_type = getattr(method, HookAttrs.HOOK_TYPE)
                    # If the hook type is not provided by any base class, it is an error.
                    # This is to ensure that the hook is only registered if it is provided by a base class.
                    # This is to avoid the case where a developer accidentally uses a hook that is not provided by a base class.
                    if method_hook_type not in self._provided_hook_types:
                        raise UnsupportedHookError(
                            f"Hook {method_hook_type} is not provided by any base class of {self.__class__.__name__}. "
                            f"(Provided Hooks: {[f'{hook_type}' for hook_type in self._provided_hook_types]})"
                        )

                    # Bind the method to the instance ("self"), extract the hook parameters,
                    # and add it to the hooks dictionary.
                    bound_method = method.__get__(self)
                    self._hooks.setdefault(method_hook_type, []).append(
                        Hook(
                            func=bound_method,
                            params=getattr(method, HookAttrs.HOOK_PARAMS, None),
                        ),
                    )

        self.debug(
            lambda: f"Provided hook types: {self._provided_hook_types} for {self.__class__.__name__}"
        )

    def get_hooks(self, *hook_types: HookType, reverse: bool = False) -> list[Hook]:
        """Get the hooks that are defined by the class for the given hook type(s), optionally reversed.
        This will return a list of Hook objects that can be inspected for their type and parameters,
        and optionally called."""
        hooks = [
            hook
            for hook_type, hooks in self._hooks.items()
            if not hook_types or hook_type in hook_types
            for hook in hooks
        ]
        if reverse:
            hooks.reverse()
        return hooks

    def for_each_hook_param(
        self,
        *hook_types: HookType,
        self_obj: Any,
        param_type: AnyT,
        lambda_func: Callable[[Hook, AnyT], None],
        reverse: bool = False,
    ) -> None:
        """Iterate over the hooks for the given hook type(s), optionally reversed.
        If a lambda_func is provided, it will be called for each parameter of the hook,
        and the hook and parameter will be passed as arguments.

        Args:
            hook_types: The hook types to iterate over.
            self_obj: The object to pass to the lambda_func.
            param_type: The type of the parameter to pass to the lambda_func (for validation).
            lambda_func: The function to call for each hook.
            reverse: Whether to iterate over the hooks in reverse order.
        """
        for hook in self.get_hooks(*hook_types, reverse=reverse):
            # in case the hook params are a callable, we need to resolve them to get the actual params
            params = hook.resolve_params(self_obj)
            if not isinstance(params, Iterable):
                raise ValueError(
                    f"Invalid hook params: {params}. Expected Iterable but got {type(params)}"
                )
            for param in params:
                self.trace(
                    lambda param=param,
                    type=param_type: f"param: {param}, param_type: {type}"
                )
                if not isinstance(param, param_type):
                    raise ValueError(
                        f"Invalid hook param: {param}. Expected {param_type} but got {type(param)}"
                    )
                # Call the lambda_func for each parameter of each hook.
                lambda_func(hook, param)

    async def run_hooks(
        self, *hook_types: HookType, reverse: bool = False, **kwargs
    ) -> None:
        """Run the hooks for the given hook type, waiting for each hook to complete before running the next one.
        Hooks are run in the order they are defined by the class, starting with hooks defined in the lowest level
        of base classes, moving up to the top level class. If more than one hook type is provided, the hooks
        from each level of classes will be run in the order of the hook types provided.

        If reverse is True, the hooks will be run in reverse order. This is useful for stop/cleanup hooks, where you
        want to start with the children and ending with the parent.

        The kwargs are passed through as keyword arguments to each hook.
        """
        exceptions: list[Exception] = []
        for hook in self.get_hooks(*hook_types, reverse=reverse):
            self.debug(lambda hook=hook: f"Running hook: {hook!r}")
            try:
                await hook(**kwargs)
            except Exception as e:
                exceptions.append(e)
                self.exception(
                    f"Error running {hook!r} hook for {self.__class__.__name__}: {e}"
                )
        if exceptions:
            raise AIPerfMultiError(
                f"Errors running {hook_types} hooks for {self.__class__.__name__}",
                exceptions,
            )

for_each_hook_param(*hook_types, self_obj, param_type, lambda_func, reverse=False)

Iterate over the hooks for the given hook type(s), optionally reversed. If a lambda_func is provided, it will be called for each parameter of the hook, and the hook and parameter will be passed as arguments.

Parameters:

Name Type Description Default
hook_types HookType

The hook types to iterate over.

()
self_obj Any

The object to pass to the lambda_func.

required
param_type AnyT

The type of the parameter to pass to the lambda_func (for validation).

required
lambda_func Callable[[Hook, AnyT], None]

The function to call for each hook.

required
reverse bool

Whether to iterate over the hooks in reverse order.

False
Source code in aiperf/common/mixins/hooks_mixin.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def for_each_hook_param(
    self,
    *hook_types: HookType,
    self_obj: Any,
    param_type: AnyT,
    lambda_func: Callable[[Hook, AnyT], None],
    reverse: bool = False,
) -> None:
    """Iterate over the hooks for the given hook type(s), optionally reversed.
    If a lambda_func is provided, it will be called for each parameter of the hook,
    and the hook and parameter will be passed as arguments.

    Args:
        hook_types: The hook types to iterate over.
        self_obj: The object to pass to the lambda_func.
        param_type: The type of the parameter to pass to the lambda_func (for validation).
        lambda_func: The function to call for each hook.
        reverse: Whether to iterate over the hooks in reverse order.
    """
    for hook in self.get_hooks(*hook_types, reverse=reverse):
        # in case the hook params are a callable, we need to resolve them to get the actual params
        params = hook.resolve_params(self_obj)
        if not isinstance(params, Iterable):
            raise ValueError(
                f"Invalid hook params: {params}. Expected Iterable but got {type(params)}"
            )
        for param in params:
            self.trace(
                lambda param=param,
                type=param_type: f"param: {param}, param_type: {type}"
            )
            if not isinstance(param, param_type):
                raise ValueError(
                    f"Invalid hook param: {param}. Expected {param_type} but got {type(param)}"
                )
            # Call the lambda_func for each parameter of each hook.
            lambda_func(hook, param)

get_hooks(*hook_types, reverse=False)

Get the hooks that are defined by the class for the given hook type(s), optionally reversed. This will return a list of Hook objects that can be inspected for their type and parameters, and optionally called.

Source code in aiperf/common/mixins/hooks_mixin.py
85
86
87
88
89
90
91
92
93
94
95
96
97
def get_hooks(self, *hook_types: HookType, reverse: bool = False) -> list[Hook]:
    """Get the hooks that are defined by the class for the given hook type(s), optionally reversed.
    This will return a list of Hook objects that can be inspected for their type and parameters,
    and optionally called."""
    hooks = [
        hook
        for hook_type, hooks in self._hooks.items()
        if not hook_types or hook_type in hook_types
        for hook in hooks
    ]
    if reverse:
        hooks.reverse()
    return hooks

run_hooks(*hook_types, reverse=False, **kwargs) async

Run the hooks for the given hook type, waiting for each hook to complete before running the next one. Hooks are run in the order they are defined by the class, starting with hooks defined in the lowest level of base classes, moving up to the top level class. If more than one hook type is provided, the hooks from each level of classes will be run in the order of the hook types provided.

If reverse is True, the hooks will be run in reverse order. This is useful for stop/cleanup hooks, where you want to start with the children and end with the parent.

The kwargs are passed through as keyword arguments to each hook.

Source code in aiperf/common/mixins/hooks_mixin.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
async def run_hooks(
    self, *hook_types: HookType, reverse: bool = False, **kwargs
) -> None:
    """Run the hooks for the given hook type, waiting for each hook to complete before running the next one.
    Hooks are run in the order they are defined by the class, starting with hooks defined in the lowest level
    of base classes, moving up to the top level class. If more than one hook type is provided, the hooks
    from each level of classes will be run in the order of the hook types provided.

    If reverse is True, the hooks will be run in reverse order. This is useful for stop/cleanup hooks, where you
    want to start with the children and ending with the parent.

    The kwargs are passed through as keyword arguments to each hook.
    """
    exceptions: list[Exception] = []
    for hook in self.get_hooks(*hook_types, reverse=reverse):
        self.debug(lambda hook=hook: f"Running hook: {hook!r}")
        try:
            await hook(**kwargs)
        except Exception as e:
            exceptions.append(e)
            self.exception(
                f"Error running {hook!r} hook for {self.__class__.__name__}: {e}"
            )
    if exceptions:
        raise AIPerfMultiError(
            f"Errors running {hook_types} hooks for {self.__class__.__name__}",
            exceptions,
        )

aiperf.common.mixins.message_bus_mixin

MessageBusClientMixin

Bases: CommunicationMixin, ABC

Mixin to provide message bus clients (pub and sub) for AIPerf components, as well as a hook to handle messages: @on_message.

Source code in aiperf/common/mixins/message_bus_mixin.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
@provides_hooks(AIPerfHook.ON_MESSAGE)
@implements_protocol(MessageBusClientProtocol)
class MessageBusClientMixin(CommunicationMixin, ABC):
    """Mixin to provide message bus clients (pub and sub)for AIPerf components, as well as
    a hook to handle messages: @on_message."""

    def __init__(self, service_config: ServiceConfig, **kwargs) -> None:
        super().__init__(service_config=service_config, **kwargs)
        # NOTE: The communication base class will automatically manage the pub/sub clients' lifecycle.
        self.sub_client = self.comms.create_sub_client(
            CommAddress.EVENT_BUS_PROXY_BACKEND
        )
        self.pub_client = self.comms.create_pub_client(
            CommAddress.EVENT_BUS_PROXY_FRONTEND
        )
        self._connection_probe_event = asyncio.Event()

    @on_init
    async def _setup_on_message_hooks(self) -> None:
        """Send subscription requests for all @on_message hook decorators."""
        subscription_map: MessageCallbackMapT = {}

        def _add_to_subscription_map(hook: Hook, message_type: MessageTypeT) -> None:
            """
            This function is called for every message_type parameter of every @on_message hook.
            We use this to build a map of message types to callbacks, which is then used to call
            subscribe_all for efficiency
            """
            self.debug(
                lambda: f"Adding subscription for message type: '{message_type}' for hook: {hook}"
            )
            subscription_map.setdefault(message_type, []).append(hook.func)

        # For each @on_message hook, add each message type to the subscription map.
        self.for_each_hook_param(
            AIPerfHook.ON_MESSAGE,
            self_obj=self,
            param_type=MessageTypeT,
            lambda_func=_add_to_subscription_map,
        )
        self.debug(lambda: f"Subscribing to {len(subscription_map)} topics")
        await self.sub_client.subscribe_all(subscription_map)

        # Subscribe to the connection probe last, to ensure the other subscriptions have been
        # subscribed to before the connection probe is received.
        await self.sub_client.subscribe(
            # NOTE: It is important to use `self.id` here, as not all message bus clients are services
            f"{MessageType.CONNECTION_PROBE}.{self.id}",
            self._process_connection_probe_message,
        )

    @on_start
    async def _wait_for_successful_probe(self) -> None:
        """Send connection probe messages until a successful probe response is received."""
        self.debug(lambda: f"Waiting for connection probe message for {self.id}")

        async def _probe_loop() -> None:
            while not self.stop_requested:
                try:
                    await asyncio.wait_for(
                        self._probe_and_wait_for_response(),
                        timeout=DEFAULT_CONNECTION_PROBE_INTERVAL,
                    )
                    break
                except asyncio.TimeoutError:
                    self.debug(
                        "Timeout waiting for connection probe message, sending another probe"
                    )
                    await yield_to_event_loop()

        await asyncio.wait_for(_probe_loop(), timeout=DEFAULT_CONNECTION_PROBE_TIMEOUT)

    async def _process_connection_probe_message(
        self, message: ConnectionProbeMessage
    ) -> None:
        """Process a connection probe message."""
        self.debug(lambda: f"Received connection probe message: {message}")
        self._connection_probe_event.set()

    async def _probe_and_wait_for_response(self) -> None:
        """Wait for a connection probe message."""
        await self.publish(
            ConnectionProbeMessage(service_id=self.id, target_service_id=self.id)
        )
        await self._connection_probe_event.wait()

    async def subscribe(
        self,
        message_type: MessageTypeT,
        callback: Callable[[Message], Coroutine[Any, Any, None]],
    ) -> None:
        """Subscribe to a specific message type. The callback will be called when
        a message is received for the given message type."""
        await self.sub_client.subscribe(message_type, callback)

    async def subscribe_all(
        self,
        message_callback_map: MessageCallbackMapT,
    ) -> None:
        """Subscribe to all message types in the map. The callback(s) will be called when
        a message is received for the given message type.

        Args:
            message_callback_map: A map of message types to callbacks. The callbacks can be a single callback or a list of callbacks.
        """
        await self.sub_client.subscribe_all(message_callback_map)

    async def publish(self, message: Message) -> None:
        """Publish a message. The message will be routed automatically based on the message type."""
        await self.pub_client.publish(message)

publish(message) async

Publish a message. The message will be routed automatically based on the message type.

Source code in aiperf/common/mixins/message_bus_mixin.py
139
140
141
async def publish(self, message: Message) -> None:
    """Publish a message. The message will be routed automatically based on the message type."""
    await self.pub_client.publish(message)

subscribe(message_type, callback) async

Subscribe to a specific message type. The callback will be called when a message is received for the given message type.

Source code in aiperf/common/mixins/message_bus_mixin.py
118
119
120
121
122
123
124
125
async def subscribe(
    self,
    message_type: MessageTypeT,
    callback: Callable[[Message], Coroutine[Any, Any, None]],
) -> None:
    """Subscribe to a specific message type. The callback will be called when
    a message is received for the given message type."""
    await self.sub_client.subscribe(message_type, callback)

subscribe_all(message_callback_map) async

Subscribe to all message types in the map. The callback(s) will be called when a message is received for the given message type.

Parameters:

Name Type Description Default
message_callback_map MessageCallbackMapT

A map of message types to callbacks. The callbacks can be a single callback or a list of callbacks.

required
Source code in aiperf/common/mixins/message_bus_mixin.py
127
128
129
130
131
132
133
134
135
136
137
async def subscribe_all(
    self,
    message_callback_map: MessageCallbackMapT,
) -> None:
    """Subscribe to all message types in the map. The callback(s) will be called when
    a message is received for the given message type.

    Args:
        message_callback_map: A map of message types to callbacks. The callbacks can be a single callback or a list of callbacks.
    """
    await self.sub_client.subscribe_all(message_callback_map)

aiperf.common.mixins.process_health_mixin

ProcessHealthMixin

Bases: BaseMixin

Mixin to provide process health information.

Source code in aiperf/common/mixins/process_health_mixin.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class ProcessHealthMixin(BaseMixin):
    """Mixin to provide process health information."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Initialize process-specific CPU monitoring
        self.process: psutil.Process = psutil.Process()
        self.process.cpu_percent()  # throw away the first result (will be 0)
        self.create_time: float = self.process.create_time()

        self.process_health: ProcessHealth | None = None
        self.previous: ProcessHealth | None = None

    def get_process_health(self) -> ProcessHealth:
        """Get the process health information for the current process."""

        # Get process-specific CPU and memory usage
        raw_cpu_times = self.process.cpu_times()
        cpu_times = CPUTimes(
            user=raw_cpu_times[0],
            system=raw_cpu_times[1],
            iowait=raw_cpu_times[4] if len(raw_cpu_times) > 4 else 0.0,  # type: ignore
        )

        self.previous = self.process_health

        self.process_health = ProcessHealth(
            pid=self.process.pid,
            create_time=self.create_time,
            uptime=time.time() - self.create_time,
            cpu_usage=self.process.cpu_percent(),
            memory_usage=self.process.memory_info().rss / BYTES_PER_MIB,
            io_counters=self.process.io_counters(),
            cpu_times=cpu_times,
            num_ctx_switches=CtxSwitches(*self.process.num_ctx_switches()),
            num_threads=self.process.num_threads(),
        )
        return self.process_health

get_process_health()

Get the process health information for the current process.

Source code in aiperf/common/mixins/process_health_mixin.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def get_process_health(self) -> ProcessHealth:
    """Get the process health information for the current process."""

    # Get process-specific CPU and memory usage
    raw_cpu_times = self.process.cpu_times()
    cpu_times = CPUTimes(
        user=raw_cpu_times[0],
        system=raw_cpu_times[1],
        iowait=raw_cpu_times[4] if len(raw_cpu_times) > 4 else 0.0,  # type: ignore
    )

    self.previous = self.process_health

    self.process_health = ProcessHealth(
        pid=self.process.pid,
        create_time=self.create_time,
        uptime=time.time() - self.create_time,
        cpu_usage=self.process.cpu_percent(),
        memory_usage=self.process.memory_info().rss / BYTES_PER_MIB,
        io_counters=self.process.io_counters(),
        cpu_times=cpu_times,
        num_ctx_switches=CtxSwitches(*self.process.num_ctx_switches()),
        num_threads=self.process.num_threads(),
    )
    return self.process_health

aiperf.common.mixins.pull_client_mixin

PullClientMixin

Bases: CommunicationMixin, ABC

Mixin to provide a pull client for AIPerf components using a PullClient for the specified CommAddress. Add the @on_pull_message decorator to specify a function that will be called when a pull is received.

NOTE: This currently only supports a single pull client per service, as that is our current use case.

Source code in aiperf/common/mixins/pull_client_mixin.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
@provides_hooks(AIPerfHook.ON_PULL_MESSAGE)
class PullClientMixin(CommunicationMixin, ABC):
    """Mixin to provide a pull client for AIPerf components using a PullClient for the specified CommAddress.
    Add the @on_pull_message decorator to specify a function that will be called when a pull is received.

    NOTE: This currently only supports a single pull client per service, as that is our current use case.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        pull_client_address: CommAddress,
        pull_client_bind: bool = False,
        max_pull_concurrency: int | None = None,
        **kwargs,
    ) -> None:
        super().__init__(service_config=service_config, **kwargs)
        # NOTE: The communication base class will automatically manage the pull client's lifecycle.
        self.pull_client = self.comms.create_pull_client(
            pull_client_address,
            bind=pull_client_bind,
            max_pull_concurrency=max_pull_concurrency,
        )

    @on_init
    async def _setup_pull_handler_hooks(self) -> None:
        """Configure the pull client to register callbacks for all @on_pull_message hook decorators."""

        def _register_pull_callback(hook: Hook, message_type: MessageTypeT) -> None:
            self.debug(
                lambda: f"Registering pull callback for message type: {message_type} for hook: {hook}"
            )
            self.pull_client.register_pull_callback(
                message_type=message_type,
                callback=hook.func,
            )

        # For each @on_pull_message hook, register a pull callback for each specified message type.
        self.for_each_hook_param(
            AIPerfHook.ON_PULL_MESSAGE,
            self_obj=self,
            param_type=MessageTypeT,
            lambda_func=_register_pull_callback,
        )

aiperf.common.mixins.reply_client_mixin

ReplyClientMixin

Bases: CommunicationMixin, ABC

Mixin to provide a reply client for AIPerf components using a ReplyClient for the specified CommAddress. Add the @on_request decorator to specify a function that will be called when a request is received.

NOTE: This currently only supports a single reply client per service, as that is our current use case.

Source code in aiperf/common/mixins/reply_client_mixin.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@provides_hooks(AIPerfHook.ON_REQUEST)
class ReplyClientMixin(CommunicationMixin, ABC):
    """Mixin that equips an AIPerf component with a ReplyClient bound to the given CommAddress.
    Decorate a function with @on_request to have it invoked whenever a request is received.

    NOTE: This currently only supports a single reply client per service, as that is our current use case.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        reply_client_address: CommAddress,
        reply_client_bind: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(service_config=service_config, **kwargs)
        # NOTE: Lifecycle of this client is handled automatically by the communication base class.
        self.reply_client = self.comms.create_reply_client(
            reply_client_address, bind=reply_client_bind
        )

    @on_init
    async def _setup_request_handler_hooks(self) -> None:
        """Configure the reply client to handle requests for all @request_handler hook decorators."""

        def _attach_handler(hook: Hook, message_type: MessageTypeT) -> None:
            self.debug(
                lambda: f"Registering request handler for message type: {message_type} for hook: {hook}"
            )
            self.reply_client.register_request_handler(
                service_id=self.id,
                message_type=message_type,
                handler=hook.func,
            )

        # Walk each @on_request hook and register one handler per declared message type.
        self.for_each_hook_param(
            AIPerfHook.ON_REQUEST,
            self_obj=self,
            param_type=MessageTypeT,
            lambda_func=_attach_handler,
        )

aiperf.common.mixins.task_manager_mixin

TaskManagerMixin

Bases: AIPerfLoggerMixin

Mixin to manage a set of async tasks, and provide background task loop capabilities. Can be used standalone, but it is most useful as part of the :class:AIPerfLifecycleMixin mixin, where the lifecycle methods are automatically integrated with the task manager.

Source code in aiperf/common/mixins/task_manager_mixin.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
@implements_protocol(TaskManagerProtocol)
class TaskManagerMixin(AIPerfLoggerMixin):
    """Mixin to manage a set of async tasks, and provide background task loop capabilities.
    Can be used standalone, but it is most useful as part of the :class:`AIPerfLifecycleMixin`
    mixin, where the lifecycle methods are automatically integrated with the task manager.
    """

    def __init__(self, **kwargs):
        # Strong references to in-flight tasks (protects them from premature GC).
        self.tasks: set[asyncio.Task] = set()
        super().__init__(**kwargs)

    def execute_async(self, coro: Coroutine) -> asyncio.Task:
        """Create a task from a coroutine and add it to the set of tasks, and return immediately.
        The task will be automatically cleaned up when it completes.
        """
        task = asyncio.create_task(coro)
        self.tasks.add(task)
        # Drop the reference once the task finishes so the set does not grow unbounded.
        task.add_done_callback(self.tasks.discard)
        return task

    async def wait_for_tasks(self) -> list[BaseException | None]:
        """Wait for all current tasks to complete.

        Returns:
            One entry per task: the task's exception (if any) or its result/None.
        """
        return await asyncio.gather(*list(self.tasks), return_exceptions=True)

    async def cancel_all_tasks(
        self, timeout: float = TASK_CANCEL_TIMEOUT_SHORT
    ) -> None:
        """Cancel all tasks in the set and wait for up to timeout seconds for them to complete.

        Args:
            timeout: The timeout to wait for the tasks to complete.
        """
        if not self.tasks:
            return

        task_list = list(self.tasks)
        for task in task_list:
            task.cancel()
        # BUGFIX: the timeout parameter was previously accepted but never used, and
        # cancellation was never awaited, contradicting the docstring. Wait (bounded
        # by timeout) for the cancelled tasks to actually finish.
        await asyncio.wait(task_list, timeout=timeout)

    def start_background_task(
        self,
        method: Callable,
        interval: float | Callable[[TaskManagerProtocol], float] | None = None,
        immediate: bool = False,
        stop_on_error: bool = False,
        stop_event: asyncio.Event | None = None,
    ) -> None:
        """Run a task in the background, in a loop until cancelled.

        Args:
            method: The method to run as a background task (sync or async).
            interval: Seconds between runs; may be a callable receiving self. None runs once.
            immediate: If True, run once before the first interval sleep.
            stop_on_error: If True, stop the loop on any exception.
            stop_event: Optional event that stops the loop when set.
        """
        self.execute_async(
            self._background_task_loop(
                method, interval, immediate, stop_on_error, stop_event
            )
        )

    async def _background_task_loop(
        self,
        method: Callable,
        interval: float | Callable[[TaskManagerProtocol], float] | None = None,
        immediate: bool = False,
        stop_on_error: bool = False,
        stop_event: asyncio.Event | None = None,
    ) -> None:
        """Run a background task in a loop until cancelled.

        Args:
            method: The method to run as a background task.
            interval: The interval to run the task in seconds. Can be a callable that returns the interval, and will be called with 'self' as the argument.
            immediate: If True, run the task immediately on start, otherwise wait for the interval first.
            stop_on_error: If True, stop the task on any exception, otherwise log and continue.
            stop_event: If provided, the loop exits once this event is set.
        """
        while stop_event is None or not stop_event.is_set():
            try:
                if interval is None or immediate:
                    await yield_to_event_loop()
                    # Reset immediate flag for next iteration otherwise we will not sleep
                    immediate = False
                else:
                    sleep_time = interval(self) if callable(interval) else interval
                    await asyncio.sleep(sleep_time)

                # Sync callables are offloaded to a thread so the loop is not blocked.
                if inspect.iscoroutinefunction(method):
                    await method()
                else:
                    await asyncio.to_thread(method)

                # No interval means "run once".
                if interval is None:
                    break
            except asyncio.CancelledError:
                self.debug(f"Background task {method.__name__} cancelled")
                break
            except Exception as e:
                self.exception(f"Error in background task {method.__name__}: {e}")
                if stop_on_error:
                    self.exception(
                        f"Background task {method.__name__} stopped due to error"
                    )
                    break
                # Give some time to recover, just in case
                await asyncio.sleep(0.001)

cancel_all_tasks(timeout=TASK_CANCEL_TIMEOUT_SHORT) async

Cancel all tasks in the set and wait for up to timeout seconds for them to complete.

Parameters:

Name Type Description Default
timeout float

The timeout to wait for the tasks to complete.

TASK_CANCEL_TIMEOUT_SHORT
Source code in aiperf/common/mixins/task_manager_mixin.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
async def cancel_all_tasks(
    self, timeout: float = TASK_CANCEL_TIMEOUT_SHORT
) -> None:
    """Cancel all tasks in the set and wait for up to timeout seconds for them to complete.

    Args:
        timeout: The timeout to wait for the tasks to complete.
    """
    if not self.tasks:
        return

    task_list = list(self.tasks)
    for task in task_list:
        task.cancel()
    # BUGFIX: honor the timeout — it was previously accepted but never used, and
    # cancellation was never awaited, contradicting the docstring.
    await asyncio.wait(task_list, timeout=timeout)

execute_async(coro)

Create a task from a coroutine and add it to the set of tasks, and return immediately. The task will be automatically cleaned up when it completes.

Source code in aiperf/common/mixins/task_manager_mixin.py
25
26
27
28
29
30
31
32
def execute_async(self, coro: Coroutine) -> asyncio.Task:
    """Schedule *coro* as a tracked task and return it immediately without awaiting.
    The task removes itself from the tracking set as soon as it finishes.
    """
    new_task = asyncio.create_task(coro)
    self.tasks.add(new_task)
    new_task.add_done_callback(self.tasks.discard)
    return new_task

start_background_task(method, interval=None, immediate=False, stop_on_error=False, stop_event=None)

Run a task in the background, in a loop until cancelled.

Source code in aiperf/common/mixins/task_manager_mixin.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def start_background_task(
    self,
    method: Callable,
    interval: float | Callable[[TaskManagerProtocol], float] | None = None,
    immediate: bool = False,
    stop_on_error: bool = False,
    stop_event: asyncio.Event | None = None,
) -> None:
    """Kick off *method* as a looping background task that runs until cancelled."""
    loop_coro = self._background_task_loop(
        method, interval, immediate, stop_on_error, stop_event
    )
    self.execute_async(loop_coro)

wait_for_tasks() async

Wait for all current tasks to complete.

Source code in aiperf/common/mixins/task_manager_mixin.py
34
35
36
async def wait_for_tasks(self) -> list[BaseException | None]:
    """Wait for all current tasks to complete."""
    return await asyncio.gather(*list(self.tasks), return_exceptions=True)

aiperf.common.models.base_models

AIPerfBaseModel

Bases: BaseModel

Base model for all AIPerf Pydantic models. This class is configured to allow arbitrary types to be used as fields as to allow for more flexible model definitions by end users without breaking the existing code.

The @exclude_if_none decorator can also be used to specify which fields should be excluded from the serialized model if they are None. This is a workaround for the fact that pydantic does not support specifying exclude_none on a per-field basis.

Source code in aiperf/common/models/base_models.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class AIPerfBaseModel(BaseModel):
    """Base model for every AIPerf Pydantic model. Arbitrary field types are
    permitted so that end users can define more flexible models without breaking
    the existing code.

    Fields marked with the @exclude_if_none decorator are dropped from the
    serialized output when their value is None — a workaround for pydantic not
    supporting exclude_none on a per-field basis.
    """

    _exclude_if_none_fields: ClassVar[set[str]] = set()
    """Set of field names that should be excluded from the serialized model if they
    are None. This is set by the @exclude_if_none decorator.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @model_serializer
    def _serialize_model(self) -> dict[str, Any]:
        """Serialize the model to a dictionary, omitting any @exclude_if_none
        field whose current value is None.
        """
        skippable = self._exclude_if_none_fields
        serialized: dict[str, Any] = {}
        for name, value in self:
            if value is None and name in skippable:
                continue
            serialized[name] = value
        return serialized

exclude_if_none(*field_names)

Decorator to set the _exclude_if_none_fields class attribute to the set of field names that should be excluded if they are None.

Source code in aiperf/common/models/base_models.py
10
11
12
13
14
15
16
17
18
19
20
21
22
def exclude_if_none(*field_names: str):
    """Class decorator that marks *field_names* for omission from serialization
    when their value is None (consumed by AIPerfBaseModel's serializer).

    Args:
        field_names: Names of the model fields to exclude when None.

    Returns:
        A decorator that records the field names on the model class and returns it.
    """

    def decorator(model: "type[AIPerfBaseModelT]") -> "type[AIPerfBaseModelT]":
        # BUGFIX: hasattr() also finds the attribute *inherited* from the base
        # class, so updating that set in place would mutate the single set shared
        # by ALL models. Only reuse the set when it is defined on this exact
        # class; otherwise copy whatever was inherited before extending it.
        if "_exclude_if_none_fields" not in vars(model):
            model._exclude_if_none_fields = set(
                getattr(model, "_exclude_if_none_fields", ())
            )
        model._exclude_if_none_fields.update(field_names)
        return model

    return decorator

aiperf.common.models.credit_models

CreditPhaseConfig

Bases: AIPerfBaseModel

Model for phase credit config. This is used by the TimingManager to configure the credit phases.

Source code in aiperf/common/models/credit_models.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class CreditPhaseConfig(AIPerfBaseModel):
    """Model for phase credit config. This is used by the TimingManager to configure the credit phases."""

    type: CreditPhase = Field(..., description="The type of credit phase")
    total_expected_requests: int | None = Field(
        default=None,
        ge=1,
        description="The total number of expected credits. If None, the phase is not request count based.",
    )
    expected_duration_sec: float | None = Field(
        default=None,
        ge=1,
        description="The expected duration of the credit phase in seconds. If None, the phase is not time based.",
    )

    @property
    def is_time_based(self) -> bool:
        """True when a duration has been configured for this phase."""
        return self.expected_duration_sec is not None

    @property
    def is_request_count_based(self) -> bool:
        """True when a request count has been configured for this phase."""
        return self.total_expected_requests is not None

    @property
    def is_valid(self) -> bool:
        """A phase config is valid if it is exactly one of the following:
        - is_time_based (expected_duration_sec is set and > 0)
        - is_request_count_based (total_expected_requests is set and > 0)
        """
        # Exactly one of the two modes must be configured: boolean XOR.
        return self.is_time_based != self.is_request_count_based

is_valid property

A phase config is valid if it is exactly one of the following: - is_time_based (expected_duration_sec is set and > 0) - is_request_count_based (total_expected_requests is set and > 0)

CreditPhaseStats

Bases: CreditPhaseConfig

Model for phase credit stats. Extends the CreditPhaseConfig fields to track the progress of the credit phases. How many credits were dropped and how many were returned, as well as the progress percentage of the phase.

Source code in aiperf/common/models/credit_models.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class CreditPhaseStats(CreditPhaseConfig):
    """Model for phase credit stats. Extends the CreditPhaseConfig fields to track
    how a credit phase is progressing: credits sent, credits returned, and the
    phase's completion percentage."""

    start_ns: int | None = Field(
        default=None,
        description="The start time of the credit phase in nanoseconds.",
    )
    sent_end_ns: int | None = Field(
        default=None,
        description="The time of the last sent credit in nanoseconds. If None, the phase has not sent all credits.",
    )
    end_ns: int | None = Field(
        default=None,
        ge=1,
        description="The time in which the last credit was returned from the workers in nanoseconds. If None, the phase has not completed.",
    )
    sent: int = Field(default=0, description="The number of sent credits")
    completed: int = Field(
        default=0,
        description="The number of completed credits (returned from the workers)",
    )

    @property
    def is_sending_complete(self) -> bool:
        """True once the last credit has been sent."""
        return self.sent_end_ns is not None

    @property
    def is_complete(self) -> bool:
        """True once all credits are sent AND the last one has been returned."""
        return self.is_sending_complete and self.end_ns is not None

    @property
    def is_started(self) -> bool:
        """True once the phase has recorded a start time."""
        return self.start_ns is not None

    @property
    def in_flight(self) -> int:
        """Calculate the number of in-flight credits (sent but not completed)."""
        return self.sent - self.completed

    @property
    def should_send(self) -> bool:
        """Whether the phase should send more credits."""
        if self.is_time_based:
            elapsed_ns = time.time_ns() - (self.start_ns or 0)
            return elapsed_ns <= self.expected_duration_sec * NANOS_PER_SECOND  # type: ignore
        if self.is_request_count_based:
            return self.sent < self.total_expected_requests  # type: ignore
        raise InvalidStateError("Phase is not time or request count based")

    @property
    def progress_percent(self) -> float | None:
        """Phase completion in percent, or None when it cannot be determined."""
        if self.start_ns is None:
            return None

        if self.is_complete:
            return 100

        if self.is_time_based:
            # Time based, so progress is the fraction of the duration elapsed so far.
            elapsed_ns = time.time_ns() - self.start_ns
            duration_ns = self.expected_duration_sec * NANOS_PER_SECOND  # type: ignore
            return (elapsed_ns / duration_ns) * 100

        if self.total_expected_requests is not None:
            # Credit count based, so progress is the fraction of credits returned.
            return (self.completed / self.total_expected_requests) * 100

        # We don't know the progress
        return None

    @classmethod
    def from_phase_config(cls, phase_config: CreditPhaseConfig) -> "CreditPhaseStats":
        """Create a CreditPhaseStats from a CreditPhaseConfig. This is used to initialize the stats for a phase."""
        return cls(
            type=phase_config.type,
            total_expected_requests=phase_config.total_expected_requests,
            expected_duration_sec=phase_config.expected_duration_sec,
        )

in_flight property

Calculate the number of in-flight credits (sent but not completed).

should_send property

Whether the phase should send more credits.

from_phase_config(phase_config) classmethod

Create a CreditPhaseStats from a CreditPhaseConfig. This is used to initialize the stats for a phase.

Source code in aiperf/common/models/credit_models.py
125
126
127
128
129
130
131
132
@classmethod
def from_phase_config(cls, phase_config: CreditPhaseConfig) -> "CreditPhaseStats":
    """Build a fresh stats object seeded from the given phase configuration."""
    seed = dict(
        type=phase_config.type,
        total_expected_requests=phase_config.total_expected_requests,
        expected_duration_sec=phase_config.expected_duration_sec,
    )
    return cls(**seed)

PhaseProcessingStats

Bases: AIPerfBaseModel

Model for phase processing stats. How many requests were processed and how many errors were encountered.

Source code in aiperf/common/models/credit_models.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class PhaseProcessingStats(AIPerfBaseModel):
    """Model for phase processing stats. How many requests were processed and
    how many errors were encountered."""

    # Counters default to 0 so a fresh instance can be incremented in place.
    processed: int = Field(
        default=0, description="The number of records processed successfully"
    )
    errors: int = Field(
        default=0, description="The number of record errors encountered"
    )
    total_expected_requests: int | None = Field(
        default=None,
        description="The total number of expected requests to process. If None, the phase is not request count based.",
    )

    @property
    def total_records(self) -> int:
        """The total number of records processed successfully or in error."""
        return self.processed + self.errors

total_records property

The total number of records processed successfully or in error.

aiperf.common.models.dataset_models

Audio

Bases: Media

Media that contains audio data.

Source code in aiperf/common/models/dataset_models.py
36
37
38
39
class Audio(Media):
    """Media that contains audio data."""

    # Class-level discriminator identifying this Media subclass as audio.
    media_type: ClassVar[MediaTypeT] = MediaType.AUDIO

Conversation

Bases: AIPerfBaseModel

A dataset representation of a full conversation.

A conversation is a sequence of turns between a user and an endpoint, and it contains the session ID and all the turns that constitute the conversation.

Source code in aiperf/common/models/dataset_models.py
70
71
72
73
74
75
76
77
78
79
80
class Conversation(AIPerfBaseModel):
    """A dataset representation of a full conversation.

    A conversation is a sequence of turns between a user and an endpoint,
    and it contains the session ID and all the turns that constitute the conversation.
    """

    turns: list[Turn] = Field(
        # default_factory avoids declaring a shared mutable default list.
        default_factory=list,
        description="List of turns in the conversation.",
    )
    session_id: str = Field(default="", description="Session ID of the conversation.")

Image

Bases: Media

Media that contains image data.

Source code in aiperf/common/models/dataset_models.py
30
31
32
33
class Image(Media):
    """Media that contains image data."""

    # Class-level discriminator identifying this Media subclass as an image.
    media_type: ClassVar[MediaTypeT] = MediaType.IMAGE

Media

Bases: AIPerfBaseModel

Base class for all media fields. Contains name and contents of the media data.

Source code in aiperf/common/models/dataset_models.py
13
14
15
16
17
18
19
20
21
class Media(AIPerfBaseModel):
    """Base class for all media fields. Contains name and contents of the media data."""

    name: str = Field(default="", description="Name of the media field.")

    contents: list[str] = Field(
        # default_factory avoids declaring a shared mutable default list.
        default_factory=list,
        description="List of media contents. Supports batched media payload in a single turn.",
    )

Text

Bases: Media

Media that contains text/prompt data.

Source code in aiperf/common/models/dataset_models.py
24
25
26
27
class Text(Media):
    """Media that contains text/prompt data."""

    # Class-level discriminator identifying this Media subclass as text.
    media_type: ClassVar[MediaTypeT] = MediaType.TEXT

Turn

Bases: AIPerfBaseModel

A dataset representation of a single turn within a conversation.

A turn is a single interaction between a user and an AI assistant, and it contains the timestamp, delay, and raw data that the user sends in each turn.

Source code in aiperf/common/models/dataset_models.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
@exclude_if_none("role")
class Turn(AIPerfBaseModel):
    """A dataset representation of a single turn within a conversation.

    A turn is a single interaction between a user and an AI assistant,
    and it contains timestamp, delay, and raw data that user sends in each turn.
    """

    timestamp: int | None = Field(
        default=None, description="Timestamp of the turn in milliseconds."
    )
    delay: int | None = Field(
        default=None,
        description="Amount of milliseconds to wait before sending the turn.",
    )
    model: str | None = Field(default=None, description="Model name used for the turn.")
    role: str | None = Field(default=None, description="Role of the turn.")
    texts: list[Text] = Field(
        default=[], description="Collection of text data in each turn."
    )
    images: list[Image] = Field(
        default=[], description="Collection of image data in each turn."
    )
    audios: list[Audio] = Field(
        default=[], description="Collection of audio data in each turn."
    )

aiperf.common.models.error_models

ErrorDetails

Bases: AIPerfBaseModel

Encapsulates details about an error.

Source code in aiperf/common/models/error_models.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class ErrorDetails(AIPerfBaseModel):
    """Encapsulates details about an error."""

    code: int | None = Field(
        default=None,
        description="The error code.",
    )
    type: str | None = Field(
        default=None,
        description="The error type.",
    )
    message: str = Field(
        ...,
        description="The error message.",
    )

    def __eq__(self, other: Any) -> bool:
        """Check if the error details are equal by comparing the code, type, and message."""
        if not isinstance(other, ErrorDetails):
            return False
        return (self.code, self.type, self.message) == (
            other.code,
            other.type,
            other.message,
        )

    def __hash__(self) -> int:
        """Hash the error details by hashing the code, type, and message."""
        identity = (self.code, self.type, self.message)
        return hash(identity)

    @classmethod
    def from_exception(cls, e: BaseException) -> "ErrorDetails":
        """Create an error details object from an exception."""
        error_type = e.__class__.__name__
        return cls(type=error_type, message=str(e))

__eq__(other)

Check if the error details are equal by comparing the code, type, and message.

Source code in aiperf/common/models/error_models.py
26
27
28
29
30
31
32
33
34
def __eq__(self, other: Any) -> bool:
    """Check if the error details are equal by comparing the code, type, and message."""
    if not isinstance(other, ErrorDetails):
        return False
    return (self.code, self.type, self.message) == (
        other.code,
        other.type,
        other.message,
    )

__hash__()

Hash the error details by hashing the code, type, and message.

Source code in aiperf/common/models/error_models.py
36
37
38
def __hash__(self) -> int:
    """Hash the error details by hashing the code, type, and message."""
    identity = (self.code, self.type, self.message)
    return hash(identity)

from_exception(e) classmethod

Create an error details object from an exception.

Source code in aiperf/common/models/error_models.py
40
41
42
43
44
45
46
@classmethod
def from_exception(cls, e: BaseException) -> "ErrorDetails":
    """Build an ErrorDetails instance describing the given exception."""
    error_type = e.__class__.__name__
    return cls(type=error_type, message=str(e))

ErrorDetailsCount

Bases: AIPerfBaseModel

Count of error details.

Source code in aiperf/common/models/error_models.py
49
50
51
52
53
54
55
56
class ErrorDetailsCount(AIPerfBaseModel):
    """Count of error details."""

    # The distinct error being counted.
    error_details: ErrorDetails
    count: int = Field(
        ...,
        description="The count of the error details.",
    )

aiperf.common.models.health_models

ProcessHealth

Bases: AIPerfBaseModel

Model for process health data.

Source code in aiperf/common/models/health_models.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class ProcessHealth(AIPerfBaseModel):
    """Model for process health data."""

    # NOTE(review): the `| tuple` fallbacks below presumably allow raw named
    # tuples from a process-inspection library (e.g. psutil) in addition to the
    # structured model types — confirm against the producers of this model.
    pid: int | None = Field(
        default=None,
        description="The PID of the process",
    )
    create_time: float = Field(
        ..., description="The creation time of the process in seconds"
    )
    uptime: float = Field(..., description="The uptime of the process in seconds")
    cpu_usage: float = Field(
        ..., description="The current CPU usage of the process in %"
    )
    memory_usage: float = Field(
        ..., description="The current memory usage of the process in MiB (rss)"
    )
    io_counters: IOCounters | tuple | None = Field(
        default=None,
        description="The current I/O counters of the process (read_count, write_count, read_bytes, write_bytes, read_chars, write_chars)",
    )
    cpu_times: CPUTimes | tuple | None = Field(
        default=None,
        description="The current CPU times of the process (user, system, iowait)",
    )
    num_ctx_switches: CtxSwitches | tuple | None = Field(
        default=None,
        description="The current number of context switches (voluntary, involuntary)",
    )
    num_threads: int | None = Field(
        default=None,
        description="The current number of threads",
    )

aiperf.common.models.record_models

InferenceServerResponse

Bases: AIPerfBaseModel

Response from an inference client.

Source code in aiperf/common/models/record_models.py
91
92
93
94
95
96
97
class InferenceServerResponse(AIPerfBaseModel):
    """Response from an inference client."""

    # perf_counter_ns is a monotonic clock: suitable for latency math, not wall time.
    perf_ns: int = Field(
        ...,
        description="The timestamp of the response in nanoseconds (perf_counter_ns).",
    )

MetricResult

Bases: AIPerfBaseModel

The result values of a single metric.

Source code in aiperf/common/models/record_models.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class MetricResult(AIPerfBaseModel):
    """The result values of a single metric."""

    tag: MetricTagT = Field(description="The unique identifier of the metric")
    # NOTE: We do not use a MetricUnitT here, as that is harder to de-serialize from JSON strings with pydantic.
    #       If we need an instance of a MetricUnitT, lookup the unit based on the tag in the MetricRegistry.
    unit: str = Field(description="The unit of the metric, e.g. 'ms'")
    header: str = Field(
        description="The user friendly name of the metric (e.g. 'Inter Token Latency')"
    )
    # Summary statistics for the metric; each is None when it was not computed.
    avg: float | None = None
    min: int | float | None = None
    max: int | float | None = None
    # pN fields presumably hold the Nth percentile of the metric's distribution
    # — TODO confirm against the metrics computation code.
    p1: float | None = None
    p5: float | None = None
    p25: float | None = None
    p50: float | None = None
    p75: float | None = None
    p90: float | None = None
    p95: float | None = None
    p99: float | None = None
    std: float | None = None
    count: int | None = Field(
        default=None,
        description="The total number of records used to calculate the metric",
    )

ParsedResponseRecord

Bases: AIPerfBaseModel

Record of a request and its associated responses, already parsed and ready for metrics.

Source code in aiperf/common/models/record_models.py
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
class ParsedResponseRecord(AIPerfBaseModel):
    """Record of a request and its associated responses, already parsed and ready for metrics."""

    request: RequestRecord = Field(description="The original request record")
    responses: list[ResponseData] = Field(description="The parsed response data.")
    input_token_count: int | None = Field(
        default=None,
        description="The number of tokens in the input. If None, the number of tokens could not be calculated.",
    )
    output_token_count: int | None = Field(
        default=None,
        description="The number of tokens across all responses. If None, the number of tokens could not be calculated.",
    )

    @cached_property
    def start_perf_ns(self) -> int:
        """Get the start time of the request in nanoseconds (perf_counter_ns)."""
        return self.request.start_perf_ns

    @cached_property
    def timestamp_ns(self) -> int:
        """Get the wall clock timestamp of the request in nanoseconds. DO NOT USE FOR LATENCY CALCULATIONS. (time.time_ns)."""
        return self.request.timestamp_ns

    # TODO: How do we differentiate the end of the request vs the time of the last response?
    #       Which one should we use for the latency metrics?
    @cached_property
    def end_perf_ns(self) -> int:
        """Get the end time of the request in nanoseconds (perf_counter_ns).
        If request.end_perf_ns is not set, use the time of the last response.
        If there are no responses, use sys.maxsize.
        """
        if self.request.end_perf_ns:
            return self.request.end_perf_ns
        if self.responses:
            return self.responses[-1].perf_ns
        return sys.maxsize

    @cached_property
    def request_duration_ns(self) -> int:
        """Get the duration of the request in nanoseconds."""
        return self.end_perf_ns - self.start_perf_ns

    @cached_property
    def tokens_per_second(self) -> float | None:
        """Get the number of tokens per second of the request."""
        if self.output_token_count is None or self.request_duration_ns == 0:
            return None
        duration_sec = self.request_duration_ns / NANOS_PER_SECOND
        return self.output_token_count / duration_sec

    @cached_property
    def has_error(self) -> bool:
        """Check if the response record has an error."""
        return self.request.has_error

    @cached_property
    def valid(self) -> bool:
        """Check if the response record is valid.

        Checks:
        - Request has no errors
        - Has at least one response
        - Start time is before the end time
        - Response timestamps are within valid ranges

        Returns:
            bool: True if the record is valid, False otherwise.
        """
        if self.has_error or not self.responses:
            return False
        if not (0 <= self.start_perf_ns < self.end_perf_ns < sys.maxsize):
            return False
        return all(0 < r.perf_ns < sys.maxsize for r in self.responses)

end_perf_ns cached property

Get the end time of the request in nanoseconds (perf_counter_ns). If request.end_perf_ns is not set, use the time of the last response. If there are no responses, use sys.maxsize.

has_error cached property

Check if the response record has an error.

request_duration_ns cached property

Get the duration of the request in nanoseconds.

start_perf_ns cached property

Get the start time of the request in nanoseconds (perf_counter_ns).

timestamp_ns cached property

Get the wall clock timestamp of the request in nanoseconds. DO NOT USE FOR LATENCY CALCULATIONS. (time.time_ns).

tokens_per_second cached property

Get the number of tokens per second of the request.

valid cached property

Check if the response record is valid.

Checks: - Request has no errors - Has at least one response - Start time is before the end time - Response timestamps are within valid ranges

Returns:

Name Type Description
bool bool

True if the record is valid, False otherwise.

ProcessRecordsResult

Bases: AIPerfBaseModel

Result of the process records command.

Source code in aiperf/common/models/record_models.py
76
77
78
79
80
81
82
83
class ProcessRecordsResult(AIPerfBaseModel):
    """Result of the process records command."""

    # The computed profile results for the run.
    results: ProfileResults = Field(..., description="The profile results")
    # Errors collected during processing; empty when everything succeeded.
    errors: list[ErrorDetails] = Field(
        default_factory=list,
        description="Any error that occurred while processing the profile results",
    )

RequestRecord

Bases: AIPerfBaseModel

Record of a request with its associated responses.

Source code in aiperf/common/models/record_models.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
class RequestRecord(AIPerfBaseModel):
    """Record of a request with its associated responses.

    Timing fields use two clocks: `timestamp_ns` is wall-clock (time.time_ns)
    and must not be used for latency math; the `*_perf_ns` fields are
    perf_counter_ns values and are the basis for all latency properties below.
    """

    request: Any | None = Field(
        default=None,
        description="The request payload formatted for the inference API.",
    )
    conversation_id: str | None = Field(
        default=None,
        description="The ID of the conversation (if applicable).",
    )
    turn_index: int | None = Field(
        default=None,
        ge=0,
        description="The index of the turn in the conversation (if applicable).",
    )
    model_name: str | None = Field(
        default=None,
        description="The name of the model targeted by the request.",
    )
    timestamp_ns: int = Field(
        default_factory=time.time_ns,
        description="The wall clock timestamp of the request in nanoseconds. DO NOT USE FOR LATENCY CALCULATIONS. (time.time_ns).",
    )
    start_perf_ns: int = Field(
        default_factory=time.perf_counter_ns,
        description="The start reference time of the request in nanoseconds used for latency calculations (perf_counter_ns).",
    )
    end_perf_ns: int | None = Field(
        default=None,
        description="The end time of the request in nanoseconds (perf_counter_ns).",
    )
    recv_start_perf_ns: int | None = Field(
        default=None,
        description="The start time of the streaming response in nanoseconds (perf_counter_ns).",
    )
    status: int | None = Field(
        default=None,
        description="The HTTP status code of the response.",
    )
    # TODO: Maybe we could improve this with subclassing the responses to allow for more specific types.
    #       This would allow us to remove the SerializeAsAny and use a more specific type. Look at how we handle
    #       the CommandMessage and CommandResponse classes for an example.
    # NOTE: We need to use SerializeAsAny to allow for generic subclass support
    # NOTE: The order of the types is important, as that is the order they are type checked.
    #       Start with the most specific types and work towards the most general types.
    responses: SerializeAsAny[
        list[SSEMessage | TextResponse | InferenceServerResponse | Any]
    ] = Field(
        default_factory=list,
        description="The raw responses received from the request.",
    )
    error: ErrorDetails | None = Field(
        default=None,
        description="The error details if the request failed.",
    )
    delayed_ns: int | None = Field(
        default=None,
        ge=0,
        description="The number of nanoseconds the request was delayed from when it was expected to be sent, "
        "or None if the request was sent on time, or did not have a credit_drop_ns timestamp.",
    )
    credit_phase: CreditPhase = Field(
        default=CreditPhase.PROFILING,
        description="The type of credit phase (either warmup or profiling)",
    )

    @property
    def delayed(self) -> bool:
        """Check if the request was delayed."""
        return self.delayed_ns is not None and self.delayed_ns > 0

    # TODO: Most of these properties will be removed once we have proper record handling and metrics.

    @property
    def has_error(self) -> bool:
        """Check if the request record has an error."""
        return self.error is not None

    @property
    def valid(self) -> bool:
        """Check if the request record is valid by ensuring that the start time
        and response timestamps are within valid ranges.

        Returns:
            bool: True if the record is valid, False otherwise.
        """
        return not self.has_error and (
            0 <= self.start_perf_ns < sys.maxsize
            and len(self.responses) > 0
            and all(0 < response.perf_ns < sys.maxsize for response in self.responses)
        )

    @property
    def time_to_first_response_ns(self) -> int | None:
        """Get the time from request start to the first response in nanoseconds.

        Returns:
            int | None: The latency, or None if the record is invalid.
        """
        if not self.valid:
            return None
        # `valid` guarantees at least one response and 0 <= start_perf_ns, so
        # the subtraction is always well-defined. (The previous truthiness
        # check wrongly returned None for a legitimate start timestamp of 0.)
        return self.responses[0].perf_ns - self.start_perf_ns

    @property
    def time_to_second_response_ns(self) -> int | None:
        """Get the time between the first and second responses in nanoseconds.

        Returns:
            int | None: The latency, or None if the record is invalid or has
                fewer than two responses.
        """
        if not self.valid or len(self.responses) < 2:
            return None
        # `valid` guarantees every response has perf_ns > 0, so no extra
        # truthiness checks are needed here.
        return self.responses[1].perf_ns - self.responses[0].perf_ns

    @property
    def time_to_last_response_ns(self) -> int | None:
        """Get the time from request start to the end of the request in nanoseconds.

        Returns:
            int | None: The latency, or None if the record is invalid or the
                request end time was never recorded.
        """
        # `start_perf_ns` is a non-optional int, so only end_perf_ns can be
        # missing (the old `is None` check on start_perf_ns was dead code).
        if not self.valid or self.end_perf_ns is None:
            return None
        return self.end_perf_ns - self.start_perf_ns

    @property
    def inter_token_latency_ns(self) -> float | None:
        """Get the average interval between responses in nanoseconds.

        A trailing SSE "[DONE]" sentinel carries no tokens and is excluded
        from the calculation.

        Returns:
            float | None: The average interval, or None if the record is
                invalid or there are fewer than two token-bearing responses.
        """
        if not self.valid or len(self.responses) < 2:
            return None

        last_index = len(self.responses) - 1
        if (
            isinstance(self.responses[-1], SSEMessage)
            and self.responses[-1].packets
            and self.responses[-1].packets[-1].value == "[DONE]"
        ):
            # Exclude the "[DONE]" terminator from the interval calculation.
            # (The `packets` truthiness guard avoids an IndexError on an
            # SSE message with no fields.)
            last_index -= 1

        if last_index < 1:
            # Only one token-bearing response remains, so there is no
            # interval. (The previous implementation divided by zero here.)
            return None

        return (
            self.responses[last_index].perf_ns - self.responses[0].perf_ns
        ) / last_index

    def token_latency_ns(self, index: int) -> float | None:
        """Get the latency of the token at `index` in nanoseconds.

        Args:
            index: The position of the response in `self.responses`.

        Returns:
            float | None: The latency, or None if the record is invalid, the
                index is out of range, or (for index 0) the streaming start
                time was never recorded.
        """
        # Explicit bounds check instead of the previous implicit IndexError on
        # out-of-range indices (and silently-wrong results for negative ones).
        if not self.valid or not 0 <= index < len(self.responses):
            return None
        if index == 0:
            # The first token is measured from the start of the streaming
            # response; use `is None` so a legitimate timestamp of 0 is not
            # treated as missing.
            if self.recv_start_perf_ns is None:
                return None
            return self.responses[0].perf_ns - self.recv_start_perf_ns
        return self.responses[index].perf_ns - self.responses[index - 1].perf_ns

delayed property

Check if the request was delayed.

has_error property

Check if the request record has an error.

inter_token_latency_ns property

Get the interval between responses in nanoseconds.

time_to_first_response_ns property

Get the time to the first response in nanoseconds.

time_to_last_response_ns property

Get the time to the last response in nanoseconds.

time_to_second_response_ns property

Get the time to the second response in nanoseconds.

valid property

Check if the request record is valid by ensuring that the start time and response timestamps are within valid ranges.

Returns:

Name Type Description
bool bool

True if the record is valid, False otherwise.

token_latency_ns(index)

Get the latency of a token in nanoseconds.

Source code in aiperf/common/models/record_models.py
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
def token_latency_ns(self, index: int) -> float | None:
    """Get the latency of a token in nanoseconds."""
    if not self.valid or len(self.responses) < 1:
        return None
    if index == 0:
        return (
            self.responses[0].perf_ns - self.recv_start_perf_ns
            if self.recv_start_perf_ns
            else None
        )
    return (
        self.responses[index].perf_ns - self.responses[index - 1].perf_ns
        if self.responses[index].perf_ns and self.responses[index - 1].perf_ns
        else None
    )

ResponseData

Bases: AIPerfBaseModel

Base class for all response data.

Source code in aiperf/common/models/record_models.py
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
class ResponseData(AIPerfBaseModel):
    """Base class for all response data."""

    # perf_counter_ns timestamp of when this response was observed.
    perf_ns: int = Field(description="The performance timestamp of the response.")
    raw_text: list[str] = Field(description="The raw text of the response.")
    # NOTE(review): entries appear to parallel raw_text, with None where no
    # text could be parsed — confirm against the extractor implementations.
    parsed_text: list[str | None] = Field(
        description="The parsed text of the response."
    )
    token_count: int | None = Field(
        default=None,
        description="The total number of tokens in the response from the parsed text.",
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="The metadata of the response."
    )

SSEField

Bases: AIPerfBaseModel

Base model for a single field in an SSE message.

Source code in aiperf/common/models/record_models.py
113
114
115
116
117
118
119
120
121
122
123
class SSEField(AIPerfBaseModel):
    """Base model for a single field in an SSE message.

    Represents one "name: value" line of a Server-Sent Events message.
    """

    # Known field names come from SSEFieldType; plain strings are allowed so
    # unrecognized field names are preserved rather than rejected.
    name: SSEFieldType | str = Field(
        ...,
        description="The name of the field. e.g. 'data', 'event', 'id', 'retry', 'comment'.",
    )
    # None for fields that carry no value (e.g. a bare comment line).
    value: str | None = Field(
        default=None,
        description="The value of the field.",
    )

SSEMessage

Bases: InferenceServerResponse

Individual SSE message from an SSE stream. Delimited by a blank line (`\n\n`).

Source code in aiperf/common/models/record_models.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
class SSEMessage(InferenceServerResponse):
    """Individual SSE message from an SSE stream. Delimited by \n\n."""

    # Note: "fields" is a restricted keyword in pydantic
    # Each SSEField is one "name: value" line of the message (data, event, id, ...).
    packets: list[SSEField] = Field(
        default_factory=list,
        description="The fields contained in the message.",
    )

    def extract_data_content(self) -> list[str]:
        """Extract the data contents from the SSE message as a list of strings. Note that the SSE spec specifies
        that each data content should be combined and delimited by a single \n. We have left
        it as a list to allow the caller to decide how to handle the data.

        Returns:
            list[str]: A list of strings containing the data contents of the SSE message.
        """
        # Keep only 'data' fields that actually carry a value; other field
        # types (event, id, retry, comment) are ignored here.
        return [
            packet.value
            for packet in self.packets
            if packet.name == SSEFieldType.DATA and packet.value is not None
        ]

extract_data_content()

Extract the data contents from the SSE message as a list of strings. Note that the SSE spec specifies that each data content should be combined and delimited by a single `\n`. We have left it as a list to allow the caller to decide how to handle the data.

    Returns:
        list[str]: A list of strings containing the data contents of the SSE message.
Source code in aiperf/common/models/record_models.py
135
136
137
138
139
140
141
142
143
144
145
146
147
def extract_data_content(self) -> list[str]:
    """Extract the data contents from the SSE message as a list of strings. Note that the SSE spec specifies
    that each data content should be combined and delimited by a single \n. We have left
    it as a list to allow the caller to decide how to handle the data.

    Returns:
        list[str]: A list of strings containing the data contents of the SSE message.
    """
    # Collect the values of 'data' fields, skipping every other field type
    # and any data field without a value.
    data_values: list[str] = []
    for sse_field in self.packets:
        if sse_field.name != SSEFieldType.DATA:
            continue
        if sse_field.value is None:
            continue
        data_values.append(sse_field.value)
    return data_values

TextResponse

Bases: InferenceServerResponse

Raw text response from an inference client including an optional content type.

Source code in aiperf/common/models/record_models.py
100
101
102
103
104
105
106
107
108
109
110
class TextResponse(InferenceServerResponse):
    """Raw text response from an inference client including an optional content type."""

    # MIME type as reported by the client; None when not provided.
    content_type: str | None = Field(
        default=None,
        description="The content type of the response. e.g. 'text/plain', 'application/json'.",
    )
    text: str = Field(
        ...,
        description="The text of the response.",
    )

aiperf.common.models.service_models

ServiceRunInfo

Bases: AIPerfBaseModel

Base model for tracking service run information.

Source code in aiperf/common/models/service_models.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class ServiceRunInfo(AIPerfBaseModel):
    """Base model for tracking service run information."""

    service_type: ServiceTypeT = Field(
        ...,
        description="The type of service",
    )
    registration_status: ServiceRegistrationStatus = Field(
        ...,
        description="The registration status of the service",
    )
    service_id: str = Field(
        ...,
        description="The ID of the service",
    )
    # NOTE(review): typed `int | None` although default_factory always supplies
    # a timestamp — presumably None marks "never seen"; confirm against callers.
    first_seen: int | None = Field(
        default_factory=time.time_ns,
        description="The first time the service was seen",
    )
    last_seen: int | None = Field(
        default_factory=time.time_ns,
        description="The last time the service was seen",
    )
    state: LifecycleState = Field(
        default=LifecycleState.CREATED,
        description="The current state of the service",
    )

aiperf.common.models.worker_models

WorkerPhaseTaskStats

Bases: AIPerfBaseModel

Stats for the tasks that have been sent to the worker for a given credit phase.

Source code in aiperf/common/models/worker_models.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class WorkerPhaseTaskStats(AIPerfBaseModel):
    """Stats for the tasks that have been sent to the worker for a given credit phase."""

    total: int = Field(
        default=0,
        description="The total number of tasks that have been sent to the worker. "
        "Not all tasks will be completed.",
    )
    failed: int = Field(
        default=0,
        description="The number of tasks that returned an error",
    )
    completed: int = Field(
        default=0,
        description="The number of tasks that were completed successfully",
    )

    @property
    def in_progress(self) -> int:
        """The number of tasks that are currently in progress.

        Computed as the tasks dispatched to the worker minus those that have
        already finished, whether successfully or with an error.
        """
        finished = self.completed + self.failed
        return self.total - finished

in_progress property

The number of tasks that are currently in progress.

This is the total number of tasks sent to the worker minus the number of failed and successfully completed tasks.

aiperf.common.protocols

AIPerfLifecycleProtocol

Bases: TaskManagerProtocol, Protocol

Protocol for AIPerf lifecycle methods. see :class:aiperf.common.mixins.aiperf_lifecycle_mixin.AIPerfLifecycleMixin for more details.

Source code in aiperf/common/protocols.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
@runtime_checkable
class AIPerfLifecycleProtocol(TaskManagerProtocol, Protocol):
    """Protocol for AIPerf lifecycle methods.
    see :class:`aiperf.common.mixins.aiperf_lifecycle_mixin.AIPerfLifecycleMixin` for more details.
    """

    # Status flags reflecting which lifecycle transitions have occurred.
    @property
    def was_initialized(self) -> bool: ...
    @property
    def was_started(self) -> bool: ...
    @property
    def was_stopped(self) -> bool: ...
    @property
    def is_running(self) -> bool: ...

    # Events set when the corresponding lifecycle transition completes, so
    # other coroutines can await a specific state.
    initialized_event: asyncio.Event
    started_event: asyncio.Event
    stopped_event: asyncio.Event

    # Current lifecycle state of the component.
    @property
    def state(self) -> LifecycleState: ...

    # Async lifecycle transitions; implementations are provided by
    # AIPerfLifecycleMixin (see the class docstring).
    async def initialize(self) -> None: ...
    async def start(self) -> None: ...
    async def initialize_and_start(self) -> None: ...
    async def stop(self) -> None: ...

CommunicationProtocol

Bases: AIPerfLifecycleProtocol, Protocol

Protocol for the base communication layer. see :class:aiperf.common.comms.base_comms.BaseCommunication for more details.

Source code in aiperf/common/protocols.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
@runtime_checkable
class CommunicationProtocol(AIPerfLifecycleProtocol, Protocol):
    """Protocol for the base communication layer.
    see :class:`aiperf.common.comms.base_comms.BaseCommunication` for more details.
    """

    # FIX: the docstring below was previously a stray bare string expression at
    # class scope AFTER the method stub, so it never attached to get_address
    # (help() and doc generators showed nothing). Moved inside the method.
    def get_address(self, address_type: CommAddressType) -> str:
        """Get the address for the given address type; it can be an enum value
        for lookup, or a string for direct use."""
        ...

    def create_client(
        self,
        client_type: CommClientType,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
        max_pull_concurrency: int | None = None,
    ) -> CommunicationClientProtocol:
        """Create a client for the given client type and address, which will be automatically
        started and stopped with the CommunicationProtocol instance."""
        ...

    def create_pub_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> PubClientProtocol:
        """Create a PUB client for the given address, which will be automatically
        started and stopped with the CommunicationProtocol instance."""
        ...

    def create_sub_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> SubClientProtocol:
        """Create a SUB client for the given address, which will be automatically
        started and stopped with the CommunicationProtocol instance."""
        ...

    def create_push_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> PushClientProtocol:
        """Create a PUSH client for the given address, which will be automatically
        started and stopped with the CommunicationProtocol instance."""
        ...

    def create_pull_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
        max_pull_concurrency: int | None = None,
    ) -> PullClientProtocol:
        """Create a PULL client for the given address, which will be automatically
        started and stopped with the CommunicationProtocol instance."""
        ...

    def create_request_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> RequestClientProtocol:
        """Create a REQUEST client for the given address, which will be automatically
        started and stopped with the CommunicationProtocol instance."""
        ...

    def create_reply_client(
        self,
        address: CommAddressType,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> ReplyClientProtocol:
        """Create a REPLY client for the given address, which will be automatically
        started and stopped with the CommunicationProtocol instance."""
        ...

create_client(client_type, address, bind=False, socket_ops=None, max_pull_concurrency=None)

Create a client for the given client type and address, which will be automatically started and stopped with the CommunicationProtocol instance.

Source code in aiperf/common/protocols.py
207
208
209
210
211
212
213
214
215
216
217
def create_client(
    self,
    client_type: CommClientType,
    address: CommAddressType,
    bind: bool = False,
    socket_ops: dict | None = None,
    max_pull_concurrency: int | None = None,
) -> CommunicationClientProtocol:
    """Create a client for the given client type and address, which will be automatically
    started and stopped with the CommunicationProtocol instance."""
    ...

create_pub_client(address, bind=False, socket_ops=None)

Create a PUB client for the given address, which will be automatically started and stopped with the CommunicationProtocol instance.

Source code in aiperf/common/protocols.py
219
220
221
222
223
224
225
226
227
def create_pub_client(
    self,
    address: CommAddressType,
    bind: bool = False,
    socket_ops: dict | None = None,
) -> PubClientProtocol:
    """Create a PUB client for the given address, which will be automatically
    started and stopped with the CommunicationProtocol instance."""
    ...

create_pull_client(address, bind=False, socket_ops=None, max_pull_concurrency=None)

Create a PULL client for the given address, which will be automatically started and stopped with the CommunicationProtocol instance.

Source code in aiperf/common/protocols.py
249
250
251
252
253
254
255
256
257
258
def create_pull_client(
    self,
    address: CommAddressType,
    bind: bool = False,
    socket_ops: dict | None = None,
    max_pull_concurrency: int | None = None,
) -> PullClientProtocol:
    """Create a PULL client for the given address, which will be automatically
    started and stopped with the CommunicationProtocol instance."""
    ...

create_push_client(address, bind=False, socket_ops=None)

Create a PUSH client for the given address, which will be automatically started and stopped with the CommunicationProtocol instance.

Source code in aiperf/common/protocols.py
239
240
241
242
243
244
245
246
247
def create_push_client(
    self,
    address: CommAddressType,
    bind: bool = False,
    socket_ops: dict | None = None,
) -> PushClientProtocol:
    """Create a PUSH client for the given address, which will be automatically
    started and stopped with the CommunicationProtocol instance."""
    ...

create_reply_client(address, bind=False, socket_ops=None)

Create a REPLY client for the given address, which will be automatically started and stopped with the CommunicationProtocol instance.

Source code in aiperf/common/protocols.py
270
271
272
273
274
275
276
277
278
def create_reply_client(
    self,
    address: CommAddressType,
    bind: bool = False,
    socket_ops: dict | None = None,
) -> ReplyClientProtocol:
    """Create a REPLY client for the given address, which will be automatically
    started and stopped with the CommunicationProtocol instance."""
    ...

create_request_client(address, bind=False, socket_ops=None)

Create a REQUEST client for the given address, which will be automatically started and stopped with the CommunicationProtocol instance.

Source code in aiperf/common/protocols.py
260
261
262
263
264
265
266
267
268
def create_request_client(
    self,
    address: CommAddressType,
    bind: bool = False,
    socket_ops: dict | None = None,
) -> RequestClientProtocol:
    """Create a REQUEST client for the given address, which will be automatically
    started and stopped with the CommunicationProtocol instance."""
    ...

create_sub_client(address, bind=False, socket_ops=None)

Create a SUB client for the given address, which will be automatically started and stopped with the CommunicationProtocol instance.

Source code in aiperf/common/protocols.py
229
230
231
232
233
234
235
236
237
def create_sub_client(
    self,
    address: CommAddressType,
    bind: bool = False,
    socket_ops: dict | None = None,
) -> SubClientProtocol:
    """Create a SUB client for the given address, which will be automatically
    started and stopped with the CommunicationProtocol instance."""
    ...

DataExporterProtocol

Bases: Protocol

Protocol for data exporters. Any class implementing this protocol must provide an export method that takes a list of Record objects and handles exporting them appropriately.

Source code in aiperf/common/protocols.py
296
297
298
299
300
301
302
303
304
@runtime_checkable
class DataExporterProtocol(Protocol):
    """
    Protocol for data exporters.
    Any class implementing this protocol must provide an `export` method
    that takes a list of Record objects and handles exporting them appropriately.
    """

    # NOTE(review): the signature takes no records parameter despite the class
    # docstring mentioning "a list of Record objects" — implementations
    # presumably receive their data at construction time; confirm against callers.
    async def export(self) -> None: ...

HooksProtocol

Bases: Protocol

Protocol for hooks methods provided by the HooksMixin.

Source code in aiperf/common/protocols.py
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
@runtime_checkable
class HooksProtocol(Protocol):
    """Protocol for hooks methods provided by the HooksMixin."""

    # NOTE(review): the `reversed` parameter shadows the builtin of the same
    # name; kept as-is because it is part of the public signature.
    def get_hooks(self, *hook_types: HookType, reversed: bool = False) -> list[Hook]:
        """Get the hooks for the given hook type(s), optionally reversed."""
        ...

    # Extra **kwargs are forwarded to the hooks themselves.
    async def run_hooks(
        self, *hook_types: HookType, reversed: bool = False, **kwargs
    ) -> None:
        """Run the hooks for the given hook type, waiting for each hook to complete before running the next one.
        If reversed is True, the hooks will be run in reverse order. This is useful for stop/cleanup starting with
        the children and ending with the parent.
        """
        ...

get_hooks(*hook_types, reversed=False)

Get the hooks for the given hook type(s), optionally reversed.

Source code in aiperf/common/protocols.py
311
312
313
def get_hooks(self, *hook_types: HookType, reversed: bool = False) -> list[Hook]:
    """Get the hooks for the given hook type(s), optionally reversed."""
    ...

run_hooks(*hook_types, reversed=False, **kwargs) async

Run the hooks for the given hook type, waiting for each hook to complete before running the next one. If reversed is True, the hooks will be run in reverse order. This is useful for stop/cleanup starting with the children and ending with the parent.

Source code in aiperf/common/protocols.py
315
316
317
318
319
320
321
322
async def run_hooks(
    self, *hook_types: HookType, reversed: bool = False, **kwargs
) -> None:
    """Run the hooks for the given hook type, waiting for each hook to complete before running the next one.
    If reversed is True, the hooks will be run in reverse order. This is useful for stop/cleanup starting with
    the children and ending with the parent.
    """
    ...

InferenceClientProtocol

Bases: Protocol

Protocol for an inference server client.

This protocol defines the methods that must be implemented by any inference server client implementation that is compatible with the AIPerf framework.

Source code in aiperf/common/protocols.py
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
@runtime_checkable
class InferenceClientProtocol(Protocol):
    """Protocol for an inference server client.

    This protocol defines the methods that must be implemented by any inference server client
    implementation that is compatible with the AIPerf framework.
    """

    # Declaring __init__ on the protocol constrains implementations to accept
    # the model endpoint at construction time.
    def __init__(self, model_endpoint: ModelEndpointInfoT) -> None:
        """Create a new inference server client based on the provided configuration."""
        ...

    async def initialize(self) -> None:
        """Initialize the inference server client in an asynchronous context."""
        ...

    async def send_request(
        self,
        model_endpoint: ModelEndpointInfoT,
        payload: RequestInputT,
    ) -> RequestRecord:
        """Send a request to the inference server.

        This method is used to send a request to the inference server.

        Args:
            model_endpoint: The endpoint to send the request to.
            payload: The payload to send to the inference server.
        Returns:
            The raw response from the inference server.
        """
        ...

    async def close(self) -> None:
        """Close the client."""
        ...

__init__(model_endpoint)

Create a new inference server client based on the provided configuration.

Source code in aiperf/common/protocols.py
333
334
335
def __init__(self, model_endpoint: ModelEndpointInfoT) -> None:
    """Create a new inference server client based on the provided configuration."""
    ...

close() async

Close the client.

Source code in aiperf/common/protocols.py
358
359
360
async def close(self) -> None:
    """Close the client."""
    ...

initialize() async

Initialize the inference server client in an asynchronous context.

Source code in aiperf/common/protocols.py
337
338
339
async def initialize(self) -> None:
    """Initialize the inference server client in an asynchronous context."""
    ...

send_request(model_endpoint, payload) async

Send a request to the inference server.

This method is used to send a request to the inference server.

Parameters:

Name Type Description Default
model_endpoint ModelEndpointInfoT

The endpoint to send the request to.

required
payload RequestInputT

The payload to send to the inference server.

required

Returns: The raw response from the inference server.

Source code in aiperf/common/protocols.py
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
async def send_request(
    self,
    model_endpoint: ModelEndpointInfoT,
    payload: RequestInputT,
) -> RequestRecord:
    """Send a request to the inference server.

    This method is used to send a request to the inference server.

    Args:
        model_endpoint: The endpoint to send the request to.
        payload: The payload to send to the inference server.
    Returns:
        The raw response from the inference server.
    """
    ...

MessageBusClientProtocol

Bases: PubClientProtocol, SubClientProtocol, Protocol

A message bus client is a client that can publish and subscribe to messages on the event bus/message bus.

Source code in aiperf/common/protocols.py
281
282
283
284
285
286
287
288
@runtime_checkable
class MessageBusClientProtocol(PubClientProtocol, SubClientProtocol, Protocol):
    """A message bus client is a client that can publish and subscribe to messages
    on the event bus/message bus."""

    # Communication layer the pub/sub clients are created from.
    comms: CommunicationProtocol
    # Dedicated clients for the subscribe and publish sides of the bus.
    sub_client: SubClientProtocol
    pub_client: PubClientProtocol

RecordProcessorProtocol

Bases: Protocol

Protocol for a record processor that processes the incoming records and returns the results of the post processing.

Source code in aiperf/common/protocols.py
448
449
450
451
452
453
454
@runtime_checkable
class RecordProcessorProtocol(Protocol):
    """Protocol for a record processor that processes the incoming records and returns the results of the post processing."""

    # Processes one parsed record and returns its computed metrics.
    async def process_record(
        self, record: ParsedResponseRecord
    ) -> "MetricRecordDict": ...

RequestConverterProtocol

Bases: Protocol

Protocol for a request converter that converts a raw request to a formatted request for the inference server.

Source code in aiperf/common/protocols.py
375
376
377
378
379
380
381
382
383
@runtime_checkable
class RequestConverterProtocol(Protocol):
    """Protocol for a request converter that converts a raw request to a formatted request for the inference server."""

    # Converts a conversation turn into the wire format the target endpoint expects.
    async def format_payload(
        self, model_endpoint: ModelEndpointInfoT, turn: Turn
    ) -> RequestOutputT:
        """Format the turn for the inference server."""
        ...

format_payload(model_endpoint, turn) async

Format the turn for the inference server.

Source code in aiperf/common/protocols.py
379
380
381
382
383
async def format_payload(
    self, model_endpoint: ModelEndpointInfoT, turn: Turn
) -> RequestOutputT:
    """Format the turn for the inference server."""
    ...

ResponseExtractorProtocol

Bases: Protocol

Protocol for a response extractor that extracts the response data from a raw inference server response and converts it to a list of ResponseData objects.

Source code in aiperf/common/protocols.py
363
364
365
366
367
368
369
370
371
372
@runtime_checkable
class ResponseExtractorProtocol(Protocol):
    """Protocol for a response extractor that extracts the response data from a raw inference server
    response and converts it to a list of ResponseData objects."""

    async def extract_response_data(
        self, record: RequestRecord, tokenizer: Tokenizer | None
    ) -> list[ResponseData]:
        """Extract the response data from a raw inference server response and convert it to a list of ResponseData objects."""
        ...

extract_response_data(record, tokenizer) async

Extract the response data from a raw inference server response and convert it to a list of ResponseData objects.

Source code in aiperf/common/protocols.py
368
369
370
371
372
async def extract_response_data(
    self, record: RequestRecord, tokenizer: Tokenizer | None
) -> list[ResponseData]:
    """Extract the response data from a raw inference server response and convert it to a list of ResponseData objects."""
    ...

ResultsProcessorProtocol

Bases: Protocol

Protocol for a results processor that processes the results of multiple record processors, and provides the ability to summarize the results.

Source code in aiperf/common/protocols.py
457
458
459
460
461
462
463
464
465
466
@runtime_checkable
class ResultsProcessorProtocol(Protocol):
    """Protocol for a results processor that processes the results of multiple
    record processors, and provides the ability to summarize the results."""

    async def process_result(
        self, result: dict[MetricTagT, "MetricValueTypeT"]
    ) -> None: ...

    async def summarize(self) -> list["MetricResult"]: ...

ServiceManagerProtocol

Bases: AIPerfLifecycleProtocol, Protocol

Protocol for a service manager that manages the running of services using the specific ServiceRunType. Abstracts away the details of service deployment and management. See `aiperf.controller.base_service_manager.BaseServiceManager` for more details.

Source code in aiperf/common/protocols.py
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
@runtime_checkable
class ServiceManagerProtocol(AIPerfLifecycleProtocol, Protocol):
    """Protocol for a service manager that manages the running of services using the specific ServiceRunType.
    Abstracts away the details of service deployment and management.
    see :class:`aiperf.controller.base_service_manager.BaseServiceManager` for more details.
    """

    def __init__(
        self,
        required_services: dict[ServiceTypeT, int],
        service_config: "ServiceConfig",
        user_config: "UserConfig",
        log_queue: "multiprocessing.Queue | None" = None,
    ): ...

    required_services: dict[ServiceTypeT, int]
    service_map: dict[ServiceTypeT, list[ServiceRunInfo]]
    service_id_map: dict[str, ServiceRunInfo]

    async def run_service(
        self, service_type: ServiceTypeT, num_replicas: int = 1
    ) -> None: ...

    async def run_services(self, service_types: dict[ServiceTypeT, int]) -> None: ...
    async def run_required_services(self) -> None: ...
    async def shutdown_all_services(self) -> list[BaseException | None]: ...
    async def kill_all_services(self) -> list[BaseException | None]: ...
    async def stop_service(
        self, service_type: ServiceTypeT, service_id: str | None = None
    ) -> list[BaseException | None]: ...
    async def stop_services_by_type(
        self, service_types: list[ServiceTypeT]
    ) -> list[BaseException | None]: ...
    async def wait_for_all_services_registration(
        self,
        stop_event: asyncio.Event,
        timeout_seconds: float = DEFAULT_SERVICE_REGISTRATION_TIMEOUT,
    ) -> None: ...

    async def wait_for_all_services_start(
        self,
        stop_event: asyncio.Event,
        timeout_seconds: float = DEFAULT_SERVICE_START_TIMEOUT,
    ) -> None: ...

ServiceProtocol

Bases: MessageBusClientProtocol, Protocol

Protocol for a service. Essentially a MessageBusClientProtocol with service_type and service_id attributes.

Source code in aiperf/common/protocols.py
432
433
434
435
436
437
438
439
440
441
442
443
444
445
@runtime_checkable
class ServiceProtocol(MessageBusClientProtocol, Protocol):
    """Protocol for a service. Essentially a MessageBusClientProtocol with a service_type and service_id attributes."""

    def __init__(
        self,
        user_config: "UserConfig",
        service_config: "ServiceConfig",
        service_id: str | None = None,
        **kwargs,
    ) -> None: ...

    service_type: ServiceTypeT
    service_id: str

aiperf.common.tokenizer

Tokenizer

This class provides a simplified interface for using Huggingface tokenizers, with default arguments for common operations.

Source code in aiperf/common/tokenizer.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
class Tokenizer:
    """
    This class provides a simplified interface for using Huggingface
    tokenizers, with default arguments for common operations.
    """

    def __init__(self) -> None:
        """
        Initialize the tokenizer with default values for call, encode, and decode.
        """
        self._tokenizer = None
        self._call_args = {"add_special_tokens": False}
        self._encode_args = {"add_special_tokens": False}
        self._decode_args = {"skip_special_tokens": True}

    @classmethod
    def from_pretrained(
        cls,
        name: str,
        trust_remote_code: bool = False,
        revision: str = "main",
    ) -> "Tokenizer":
        """
        Factory to load a tokenizer for the given pretrained model name.

        Args:
            name: The name or path of the pretrained tokenizer model.
            trust_remote_code: Whether to trust remote code when loading the tokenizer.
            revision: The specific model version to use.
        """
        try:
            tokenizer_cls = cls()
            tokenizer_cls._tokenizer = AutoTokenizer.from_pretrained(
                name, trust_remote_code=trust_remote_code, revision=revision
            )
        except Exception as e:
            raise InitializationError(e) from e
        return tokenizer_cls

    def __call__(self, text, **kwargs) -> "BatchEncoding":
        """
        Call the underlying Huggingface tokenizer with default arguments,
        which can be overridden by kwargs.

        Args:
            text: The input text to tokenize.

        Returns:
            A BatchEncoding object containing the tokenized output.
        """
        if self._tokenizer is None:
            raise NotInitializedError("Tokenizer is not initialized.")
        return self._tokenizer(text, **{**self._call_args, **kwargs})

    def encode(self, text, **kwargs) -> list[int]:
        """
        Encode the input text into a list of token IDs.

        This method calls the underlying Huggingface tokenizer's encode
        method with default arguments, which can be overridden by kwargs.

        Args:
            text: The input text to encode.

        Returns:
            A list of token IDs.
        """
        if self._tokenizer is None:
            raise NotInitializedError("Tokenizer is not initialized.")
        return self._tokenizer.encode(text, **{**self._encode_args, **kwargs})

    def decode(self, token_ids, **kwargs) -> str:
        """
        Decode a list of token IDs back into a string.

        This method calls the underlying Huggingface tokenizer's decode
        method with default arguments, which can be overridden by kwargs.

        Args:
            token_ids: A list of token IDs to decode.

        Returns:
            The decoded string.
        """
        if self._tokenizer is None:
            raise NotInitializedError("Tokenizer is not initialized.")
        return self._tokenizer.decode(token_ids, **{**self._decode_args, **kwargs})

    @property
    def bos_token_id(self) -> int:
        """
        Return the beginning-of-sequence (BOS) token ID.
        """
        if self._tokenizer is None:
            raise NotInitializedError("Tokenizer is not initialized.")
        return self._tokenizer.bos_token_id

    def __repr__(self) -> str:
        """
        Return a string representation of the underlying tokenizer.

        Returns:
            The string representation of the tokenizer.
        """
        return self._tokenizer.__repr__()

    def __str__(self) -> str:
        """
        Return a user-friendly string representation of the underlying tokenizer.

        Returns:
            The string representation of the tokenizer.
        """
        return self._tokenizer.__str__()

bos_token_id property

Return the beginning-of-sequence (BOS) token ID.

__call__(text, **kwargs)

Call the underlying Huggingface tokenizer with default arguments, which can be overridden by kwargs.

Parameters:

Name Type Description Default
text

The input text to tokenize.

required

Returns:

Type Description
BatchEncoding

A BatchEncoding object containing the tokenized output.

Source code in aiperf/common/tokenizer.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def __call__(self, text, **kwargs) -> "BatchEncoding":
    """
    Call the underlying Huggingface tokenizer with default arguments,
    which can be overridden by kwargs.

    Args:
        text: The input text to tokenize.

    Returns:
        A BatchEncoding object containing the tokenized output.
    """
    if self._tokenizer is None:
        raise NotInitializedError("Tokenizer is not initialized.")
    return self._tokenizer(text, **{**self._call_args, **kwargs})

__init__()

Initialize the tokenizer with default values for call, encode, and decode.

Source code in aiperf/common/tokenizer.py
28
29
30
31
32
33
34
35
def __init__(self) -> None:
    """
    Initialize the tokenizer with default values for call, encode, and decode.
    """
    self._tokenizer = None
    self._call_args = {"add_special_tokens": False}
    self._encode_args = {"add_special_tokens": False}
    self._decode_args = {"skip_special_tokens": True}

__repr__()

Return a string representation of the underlying tokenizer.

Returns:

Type Description
str

The string representation of the tokenizer.

Source code in aiperf/common/tokenizer.py
119
120
121
122
123
124
125
126
def __repr__(self) -> str:
    """
    Return a string representation of the underlying tokenizer.

    Returns:
        The string representation of the tokenizer.
    """
    return self._tokenizer.__repr__()

__str__()

Return a user-friendly string representation of the underlying tokenizer.

Returns:

Type Description
str

The string representation of the tokenizer.

Source code in aiperf/common/tokenizer.py
128
129
130
131
132
133
134
135
def __str__(self) -> str:
    """
    Return a user-friendly string representation of the underlying tokenizer.

    Returns:
        The string representation of the tokenizer.
    """
    return self._tokenizer.__str__()

decode(token_ids, **kwargs)

Decode a list of token IDs back into a string.

This method calls the underlying Huggingface tokenizer's decode method with default arguments, which can be overridden by kwargs.

Parameters:

Name Type Description Default
token_ids

A list of token IDs to decode.

required

Returns:

Type Description
str

The decoded string.

Source code in aiperf/common/tokenizer.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def decode(self, token_ids, **kwargs) -> str:
    """
    Decode a list of token IDs back into a string.

    This method calls the underlying Huggingface tokenizer's decode
    method with default arguments, which can be overridden by kwargs.

    Args:
        token_ids: A list of token IDs to decode.

    Returns:
        The decoded string.
    """
    if self._tokenizer is None:
        raise NotInitializedError("Tokenizer is not initialized.")
    return self._tokenizer.decode(token_ids, **{**self._decode_args, **kwargs})

encode(text, **kwargs)

Encode the input text into a list of token IDs.

This method calls the underlying Huggingface tokenizer's encode method with default arguments, which can be overridden by kwargs.

Parameters:

Name Type Description Default
text

The input text to encode.

required

Returns:

Type Description
list[int]

A list of token IDs.

Source code in aiperf/common/tokenizer.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def encode(self, text, **kwargs) -> list[int]:
    """
    Encode the input text into a list of token IDs.

    This method calls the underlying Huggingface tokenizer's encode
    method with default arguments, which can be overridden by kwargs.

    Args:
        text: The input text to encode.

    Returns:
        A list of token IDs.
    """
    if self._tokenizer is None:
        raise NotInitializedError("Tokenizer is not initialized.")
    return self._tokenizer.encode(text, **{**self._encode_args, **kwargs})

from_pretrained(name, trust_remote_code=False, revision='main') classmethod

Factory to load a tokenizer for the given pretrained model name.

Parameters:

Name Type Description Default
name str

The name or path of the pretrained tokenizer model.

required
trust_remote_code bool

Whether to trust remote code when loading the tokenizer.

False
revision str

The specific model version to use.

'main'
Source code in aiperf/common/tokenizer.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@classmethod
def from_pretrained(
    cls,
    name: str,
    trust_remote_code: bool = False,
    revision: str = "main",
) -> "Tokenizer":
    """
    Factory to load a tokenizer for the given pretrained model name.

    Args:
        name: The name or path of the pretrained tokenizer model.
        trust_remote_code: Whether to trust remote code when loading the tokenizer.
        revision: The specific model version to use.
    """
    try:
        tokenizer_cls = cls()
        tokenizer_cls._tokenizer = AutoTokenizer.from_pretrained(
            name, trust_remote_code=trust_remote_code, revision=revision
        )
    except Exception as e:
        raise InitializationError(e) from e
    return tokenizer_cls

aiperf.common.types

This module defines commonly used alias types for AIPerf. This both helps prevent circular imports and helps with type hinting.

aiperf.common.utils

call_all_functions(funcs, *args, **kwargs) async

Call all functions in the list with the given arguments.

Parameters:

Name Type Description Default
funcs

The functions (sync or async) to call.

required
*args

The arguments to pass to the functions.

()
**kwargs

The keyword arguments to pass to the functions.

{}

Raises:

Type Description
AIPerfMultiError

If any of the functions raise an exception.

Source code in aiperf/common/utils.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
async def call_all_functions(funcs: list[Callable], *args, **kwargs) -> None:
    """Call all functions in the list with the given name.

    Args:
        obj: The object to call the functions on.
        func_names: The names of the functions to call.
        *args: The arguments to pass to the functions.
        **kwargs: The keyword arguments to pass to the functions.

    Raises:
        AIPerfMultiError: If any of the functions raise an exception.
    """

    exceptions = []
    for func in funcs:
        try:
            if inspect.iscoroutinefunction(func):
                await func(*args, **kwargs)
            else:
                func(*args, **kwargs)
        except Exception as e:
            # TODO: error handling, logging
            traceback.print_exc()
            exceptions.append(e)

    if len(exceptions) > 0:
        raise AIPerfMultiError("Errors calling functions", exceptions)

call_all_functions_self(self_, funcs, *args, **kwargs) async

Call all functions in the list with the given arguments, passing self_ as the first argument to each.

Parameters:

Name Type Description Default
self_

The object to pass as the first argument to each function.

required
funcs

The functions (sync or async) to call.

required
*args

The arguments to pass to the functions.

()
**kwargs

The keyword arguments to pass to the functions.

{}

Raises:

Type Description
AIPerfMultiError

If any of the functions raise an exception.

Source code in aiperf/common/utils.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
async def call_all_functions_self(
    self_: object, funcs: list[Callable], *args, **kwargs
) -> None:
    """Call all functions in the list with the given name.

    Args:
        obj: The object to call the functions on.
        func_names: The names of the functions to call.
        *args: The arguments to pass to the functions.
        **kwargs: The keyword arguments to pass to the functions.

    Raises:
        AIPerfMultiError: If any of the functions raise an exception.
    """

    exceptions = []
    for func in funcs:
        try:
            if inspect.iscoroutinefunction(func):
                await func(self_, *args, **kwargs)
            else:
                func(self_, *args, **kwargs)
        except Exception as e:
            # TODO: error handling, logging
            traceback.print_exc()
            exceptions.append(e)

    if len(exceptions) > 0:
        raise AIPerfMultiError("Errors calling functions", exceptions)

load_json_str(json_str, func=lambda x: x)

Deserializes JSON encoded string into Python object.

Parameters:

Name Type Description Default
- json_str

string JSON encoded string

required
- func

callable A function that takes the deserialized JSON object. This can be used to run validation checks on the object. Defaults to the identity function.

lambda x: x
Source code in aiperf/common/utils.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def load_json_str(json_str: str, func: Callable = lambda x: x) -> dict[str, Any]:
    """
    Deserializes JSON encoded string into Python object.

    Args:
      - json_str: string
          JSON encoded string
      - func: callable
          A function that takes deserialized JSON object. This can be used to
          run validation checks on the object. Defaults to identity function.
    """
    try:
        # Note: orjson may not parse JSON the same way as Python's standard json library,
        # notably being stricter on UTF-8 conformance.
        # Refer to https://github.com/ijl/orjson?tab=readme-ov-file#str for details.
        return func(orjson.loads(json_str))
    except orjson.JSONDecodeError:
        snippet = json_str[:200] + ("..." if len(json_str) > 200 else "")
        _logger.error(f"Failed to parse JSON string: '{snippet}'")
        raise

yield_to_event_loop() async

Yield to the event loop. This forces the current coroutine to yield and allow other coroutines to run, preventing starvation. Use this when you do not want to delay your coroutine via sleep, but still want to allow other coroutines to run if there is a potential for an infinite loop.

Source code in aiperf/common/utils.py
101
102
103
104
105
106
107
async def yield_to_event_loop() -> None:
    """Yield to the event loop. This forces the current coroutine to yield and allow
    other coroutines to run, preventing starvation. Use this when you do not want to
    delay your coroutine via sleep, but still want to allow other coroutines to run if
    there is a potential for an infinite loop.
    """
    await asyncio.sleep(0)

aiperf.controller.base_service_manager

BaseServiceManager

Bases: AIPerfLifecycleMixin, ABC

Base class for service managers. It provides a common interface for managing services.

Source code in aiperf/controller/base_service_manager.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
@implements_protocol(ServiceManagerProtocol)
class BaseServiceManager(AIPerfLifecycleMixin, ABC):
    """
    Base class for service managers. It provides a common interface for managing services.
    """

    def __init__(
        self,
        required_services: dict[ServiceTypeT, int],
        service_config: ServiceConfig,
        user_config: UserConfig,
        **kwargs,
    ):
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            **kwargs,
        )
        self.required_services = required_services
        self.service_config = service_config
        self.user_config = user_config
        self.kwargs = kwargs
        # Maps to track service information
        self.service_map: dict[ServiceTypeT, list[ServiceRunInfo]] = {}

        # Create service ID map for component lookups
        self.service_id_map: dict[str, ServiceRunInfo] = {}

    @on_start
    async def _start_service_manager(self) -> None:
        await self.run_required_services()

    @on_stop
    async def _stop_service_manager(self) -> None:
        await self.shutdown_all_services()

    async def run_services(
        self, service_types: dict[ServiceTypeT, int]
    ) -> list[BaseException | None]:
        return await asyncio.gather(
            *[
                self.run_service(service_type, num_replicas)
                for service_type, num_replicas in service_types.items()
            ],
            return_exceptions=True,
        )

    @abstractmethod
    async def stop_service(
        self, service_type: ServiceTypeT, service_id: str | None = None
    ) -> list[BaseException | None]: ...

    # TODO: This stuff needs some major cleanup

    async def stop_services_by_type(
        self, service_types: list[ServiceTypeT]
    ) -> list[BaseException | None]:
        """Stop a set of services."""
        results = await asyncio.gather(
            *[self.stop_service(service_type) for service_type in service_types],
            return_exceptions=True,
        )
        output: list[BaseException | None] = []
        for result in results:
            if isinstance(result, list):
                output.extend(result)
            else:
                output.append(result)
        return output

    async def run_required_services(self) -> None:
        await self.run_services(self.required_services)

    @abstractmethod
    async def run_service(
        self, service_type: ServiceTypeT, num_replicas: int = 1
    ) -> None:
        pass

    @abstractmethod
    async def shutdown_all_services(self) -> list[BaseException | None]:
        pass

    @abstractmethod
    async def kill_all_services(self) -> list[BaseException | None]:
        pass

    @abstractmethod
    async def wait_for_all_services_registration(
        self,
        stop_event: asyncio.Event,
        timeout_seconds: float = DEFAULT_SERVICE_REGISTRATION_TIMEOUT,
    ) -> None:
        pass

    @abstractmethod
    async def wait_for_all_services_start(
        self,
        stop_event: asyncio.Event,
        timeout_seconds: float = DEFAULT_SERVICE_START_TIMEOUT,
    ) -> None:
        pass

stop_services_by_type(service_types) async

Stop a set of services.

Source code in aiperf/controller/base_service_manager.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
async def stop_services_by_type(
    self, service_types: list[ServiceTypeT]
) -> list[BaseException | None]:
    """Stop a set of services."""
    results = await asyncio.gather(
        *[self.stop_service(service_type) for service_type in service_types],
        return_exceptions=True,
    )
    output: list[BaseException | None] = []
    for result in results:
        if isinstance(result, list):
            output.extend(result)
        else:
            output.append(result)
    return output

aiperf.controller.kubernetes_service_manager

KubernetesServiceManager

Bases: BaseServiceManager

Service Manager for starting and stopping services in a Kubernetes cluster.

Source code in aiperf/controller/kubernetes_service_manager.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
@implements_protocol(ServiceManagerProtocol)
@ServiceManagerFactory.register(ServiceRunType.KUBERNETES)
class KubernetesServiceManager(BaseServiceManager):
    """
    Service Manager for starting and stopping services in a Kubernetes cluster.
    """

    def __init__(
        self,
        required_services: dict[ServiceTypeT, int],
        service_config: ServiceConfig,
        user_config: UserConfig,
        **kwargs,
    ):
        super().__init__(required_services, service_config, user_config, **kwargs)

    async def run_service(
        self, service_type: ServiceTypeT, num_replicas: int = 1
    ) -> None:
        """Run a service as a Kubernetes pod."""
        self.logger.debug(f"Running service {service_type} as a Kubernetes pod")
        # TODO: Implement Kubernetes
        raise NotImplementedError(
            "KubernetesServiceManager.run_service not implemented"
        )

    async def shutdown_all_services(self) -> list[BaseException | None]:
        """Stop all required services as Kubernetes pods."""
        self.logger.debug("Stopping all required services as Kubernetes pods")
        # TODO: Implement Kubernetes
        raise NotImplementedError(
            "KubernetesServiceManager.stop_all_services not implemented"
        )

    async def kill_all_services(self) -> list[BaseException | None]:
        """Kill all required services as Kubernetes pods."""
        self.logger.debug("Killing all required services as Kubernetes pods")
        # TODO: Implement Kubernetes
        raise NotImplementedError(
            "KubernetesServiceManager.kill_all_services not implemented"
        )

    async def wait_for_all_services_registration(
        self,
        stop_event: asyncio.Event,
        timeout_seconds: float = DEFAULT_SERVICE_REGISTRATION_TIMEOUT,
    ) -> None:
        """Wait for all required services to be registered in Kubernetes."""
        self.logger.debug(
            "Waiting for all required services to be registered in Kubernetes"
        )
        # TODO: Implement Kubernetes
        raise NotImplementedError(
            "KubernetesServiceManager.wait_for_all_services_registration not implemented"
        )

    async def wait_for_all_services_start(
        self,
        stop_event: asyncio.Event,
        timeout_seconds: float = DEFAULT_SERVICE_START_TIMEOUT,
    ) -> None:
        """Wait for all required services to be started in Kubernetes."""
        self.logger.debug(
            "Waiting for all required services to be started in Kubernetes"
        )
        # TODO: Implement Kubernetes
        raise NotImplementedError(
            "KubernetesServiceManager.wait_for_all_services_start not implemented"
        )

kill_all_services() async

Kill all required services as Kubernetes pods.

Source code in aiperf/controller/kubernetes_service_manager.py
62
63
64
65
66
67
68
async def kill_all_services(self) -> list[BaseException | None]:
    """Kill all required services as Kubernetes pods."""
    self.logger.debug("Killing all required services as Kubernetes pods")
    # TODO: Implement Kubernetes
    raise NotImplementedError(
        "KubernetesServiceManager.kill_all_services not implemented"
    )

run_service(service_type, num_replicas=1) async

Run a service as a Kubernetes pod.

Source code in aiperf/controller/kubernetes_service_manager.py
44
45
46
47
48
49
50
51
52
async def run_service(
    self, service_type: ServiceTypeT, num_replicas: int = 1
) -> None:
    """Run a service as a Kubernetes pod."""
    self.logger.debug(f"Running service {service_type} as a Kubernetes pod")
    # TODO: Implement Kubernetes
    raise NotImplementedError(
        "KubernetesServiceManager.run_service not implemented"
    )

shutdown_all_services() async

Stop all required services as Kubernetes pods.

Source code in aiperf/controller/kubernetes_service_manager.py
54
55
56
57
58
59
60
async def shutdown_all_services(self) -> list[BaseException | None]:
    """Stop all required services as Kubernetes pods."""
    self.logger.debug("Stopping all required services as Kubernetes pods")
    # TODO: Implement Kubernetes
    raise NotImplementedError(
        "KubernetesServiceManager.stop_all_services not implemented"
    )

wait_for_all_services_registration(stop_event, timeout_seconds=DEFAULT_SERVICE_REGISTRATION_TIMEOUT) async

Wait for all required services to be registered in Kubernetes.

Source code in aiperf/controller/kubernetes_service_manager.py
70
71
72
73
74
75
76
77
78
79
80
81
82
async def wait_for_all_services_registration(
    self,
    stop_event: asyncio.Event,
    timeout_seconds: float = DEFAULT_SERVICE_REGISTRATION_TIMEOUT,
) -> None:
    """Wait for all required services to be registered in Kubernetes."""
    self.logger.debug(
        "Waiting for all required services to be registered in Kubernetes"
    )
    # TODO: Implement Kubernetes
    raise NotImplementedError(
        "KubernetesServiceManager.wait_for_all_services_registration not implemented"
    )

wait_for_all_services_start(stop_event, timeout_seconds=DEFAULT_SERVICE_START_TIMEOUT) async

Wait for all required services to be started in Kubernetes.

Source code in aiperf/controller/kubernetes_service_manager.py
84
85
86
87
88
89
90
91
92
93
94
95
96
async def wait_for_all_services_start(
    self,
    stop_event: asyncio.Event,
    timeout_seconds: float = DEFAULT_SERVICE_START_TIMEOUT,
) -> None:
    """Wait for all required services to be started in Kubernetes."""
    self.logger.debug(
        "Waiting for all required services to be started in Kubernetes"
    )
    # TODO: Implement Kubernetes
    raise NotImplementedError(
        "KubernetesServiceManager.wait_for_all_services_start not implemented"
    )

ServiceKubernetesRunInfo

Bases: BaseModel

Information about a service running in a Kubernetes pod.

Source code in aiperf/controller/kubernetes_service_manager.py
20
21
22
23
24
25
class ServiceKubernetesRunInfo(BaseModel):
    """Information about a service running in a Kubernetes pod."""

    # Name of the pod hosting the service.
    pod_name: str
    # Name of the cluster node the pod was scheduled onto.
    node_name: str
    # Kubernetes namespace the pod lives in.
    namespace: str

aiperf.controller.multiprocess_service_manager

MultiProcessRunInfo

Bases: BaseModel

Information about a service running as a multiprocessing process.

Source code in aiperf/controller/multiprocess_service_manager.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class MultiProcessRunInfo(BaseModel):
    """Information about a service running as a multiprocessing process."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    process: Process | SpawnProcess | ForkProcess | None = Field(default=None)
    service_type: ServiceTypeT = Field(
        ...,
        description="Type of service running in the process",
    )
    service_id: str = Field(
        ...,
        description="ID of the service running in the process",
    )

MultiProcessServiceManager

Bases: BaseServiceManager

Service Manager for starting and stopping services as multiprocessing processes.

Source code in aiperf/controller/multiprocess_service_manager.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
@implements_protocol(ServiceManagerProtocol)
@ServiceManagerFactory.register(ServiceRunType.MULTIPROCESSING)
class MultiProcessServiceManager(BaseServiceManager):
    """
    Service Manager for starting and stopping services as multiprocessing processes.
    """

    def __init__(
        self,
        required_services: dict[ServiceTypeT, int],
        service_config: ServiceConfig,
        user_config: UserConfig,
        log_queue: "multiprocessing.Queue | None" = None,
        **kwargs,
    ):
        super().__init__(required_services, service_config, user_config, **kwargs)
        # Run info for every process we have spawned, in spawn order.
        self.multi_process_info: list[MultiProcessRunInfo] = []
        # Queue that child processes use to forward log records to the parent.
        self.log_queue = log_queue

    async def run_service(
        self, service_type: ServiceTypeT, num_replicas: int = 1
    ) -> None:
        """Run a service with the given number of replicas.

        Each replica is spawned as a daemon process with a unique service id
        and tracked in ``multi_process_info`` for later shutdown/kill.
        """
        service_class = ServiceFactory.get_class_from_type(service_type)

        for _ in range(num_replicas):
            service_id = f"{service_type}_{uuid.uuid4().hex[:8]}"
            process = Process(
                target=bootstrap_and_run_service,
                name=f"{service_type}_process",
                kwargs={
                    "service_class": service_class,
                    "service_id": service_id,
                    "service_config": self.service_config,
                    "user_config": self.user_config,
                    "log_queue": self.log_queue,
                },
                daemon=True,
            )

            process.start()

            # Default-bound lambda args capture this iteration's values for
            # the lazy log message (avoids the late-binding closure pitfall).
            self.debug(
                lambda pid=process.pid,
                type=service_type: f"Service {type} started as process (pid: {pid})"
            )

            self.multi_process_info.append(
                MultiProcessRunInfo(
                    process=process,
                    service_type=service_type,
                    service_id=service_id,
                )
            )

    async def stop_service(
        self, service_type: ServiceTypeT, service_id: str | None = None
    ) -> list[BaseException | None]:
        """Stop all processes of ``service_type`` (or one specific service id).

        Returns one entry per stopped process: ``None`` on success, or the
        exception raised while waiting for that process to exit.
        """
        self.debug(lambda: f"Stopping {service_type} process(es) with id: {service_id}")
        tasks = []
        for info in list(self.multi_process_info):
            if info.service_type == service_type and (
                service_id is None or info.service_id == service_id
            ):
                task = asyncio.create_task(self._wait_for_process(info))
                # Drop the entry from our bookkeeping once the wait finishes.
                task.add_done_callback(
                    lambda _, info=info: self.multi_process_info.remove(info)
                )
                tasks.append(task)
        return await asyncio.gather(*tasks, return_exceptions=True)

    async def shutdown_all_services(self) -> list[BaseException | None]:
        """Stop all required services as multiprocessing processes."""
        self.debug("Stopping all service processes")

        # Wait for all to finish in parallel
        return await asyncio.gather(
            *[self._wait_for_process(info) for info in self.multi_process_info],
            return_exceptions=True,
        )

    async def kill_all_services(self) -> list[BaseException | None]:
        """Kill all required services as multiprocessing processes."""
        self.debug("Killing all service processes")

        # Kill all processes
        for info in self.multi_process_info:
            if info.process:
                info.process.kill()

        # Wait for all to finish in parallel
        return await asyncio.gather(
            *[self._wait_for_process(info) for info in self.multi_process_info],
            return_exceptions=True,
        )

    async def wait_for_all_services_registration(
        self,
        stop_event: asyncio.Event,
        timeout_seconds: float = DEFAULT_SERVICE_REGISTRATION_TIMEOUT,
    ) -> None:
        """Wait for all required services to be registered.

        Args:
            stop_event: Event to check if operation should be cancelled
            timeout_seconds: Maximum time to wait in seconds

        Raises:
            AIPerfError: if any required service failed to register before
                the timeout elapsed.
        """
        self.debug("Waiting for all required services to register...")

        # Get the set of required service types for checking completion
        required_types = set(self.required_services.keys())

        # TODO: Can this be done better by using asyncio.Event()?

        async def _wait_for_registration():
            while not stop_event.is_set():
                # Get all registered service types from the id map
                registered_types = {
                    service_info.service_type
                    for service_info in self.service_id_map.values()
                    if service_info.registration_status
                    == ServiceRegistrationStatus.REGISTERED
                }

                # Check if all required types are registered
                if required_types.issubset(registered_types):
                    return

                # Wait a bit before checking again
                await asyncio.sleep(0.5)

        try:
            await asyncio.wait_for(_wait_for_registration(), timeout=timeout_seconds)
        except asyncio.TimeoutError as e:
            # Log which services didn't register in time
            registered_types_set = {
                service_info.service_type
                for service_info in self.service_id_map.values()
                if service_info.registration_status
                == ServiceRegistrationStatus.REGISTERED
            }

            for service_type in required_types:
                if service_type not in registered_types_set:
                    self.error(
                        f"Service {service_type} failed to register within timeout"
                    )

            raise AIPerfError("Some services failed to register within timeout") from e

    async def _wait_for_process(self, info: MultiProcessRunInfo) -> None:
        """Terminate a process and wait for it to exit, killing it on timeout.

        ``Process.join(timeout=...)`` returns silently (it does NOT raise)
        when the timeout expires, so the previous ``except asyncio.TimeoutError``
        kill-fallback was unreachable. We now check ``is_alive()`` after the
        join and escalate to ``kill()`` if the process ignored the terminate.
        """
        if not info.process or not info.process.is_alive():
            return

        info.process.terminate()
        await asyncio.to_thread(
            info.process.join, timeout=TASK_CANCEL_TIMEOUT_SHORT
        )
        if info.process.is_alive():
            self.warning(
                f"Service {info.service_type} process (pid: {info.process.pid}) did not terminate gracefully, killing"
            )
            info.process.kill()
        else:
            self.debug(
                f"Service {info.service_type} process stopped (pid: {info.process.pid})"
            )

    async def wait_for_all_services_start(
        self,
        stop_event: asyncio.Event,
        timeout_seconds: float = DEFAULT_SERVICE_START_TIMEOUT,
    ) -> None:
        """Wait for all required services to be started."""
        self.debug("Waiting for all required services to start...")
        self.warning(
            "Waiting for all required services to start is not implemented for multiprocessing"
        )

kill_all_services() async

Kill all required services as multiprocessing processes.

Source code in aiperf/controller/multiprocess_service_manager.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
async def kill_all_services(self) -> list[BaseException | None]:
    """Kill all required services as multiprocessing processes."""
    self.debug("Killing all service processes")

    # Kill all processes
    for info in self.multi_process_info:
        if info.process:
            info.process.kill()

    # Wait for all to finish in parallel
    return await asyncio.gather(
        *[self._wait_for_process(info) for info in self.multi_process_info],
        return_exceptions=True,
    )

run_service(service_type, num_replicas=1) async

Run a service with the given number of replicas.

Source code in aiperf/controller/multiprocess_service_manager.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
async def run_service(
    self, service_type: ServiceTypeT, num_replicas: int = 1
) -> None:
    """Spawn ``num_replicas`` daemon processes running ``service_type``.

    Each replica receives a unique service id and is recorded in
    ``multi_process_info`` so it can be stopped or killed later.
    """
    target_class = ServiceFactory.get_class_from_type(service_type)

    for _ in range(num_replicas):
        replica_id = f"{service_type}_{uuid.uuid4().hex[:8]}"
        proc = Process(
            target=bootstrap_and_run_service,
            name=f"{service_type}_process",
            kwargs={
                "service_class": target_class,
                "service_id": replica_id,
                "service_config": self.service_config,
                "user_config": self.user_config,
                "log_queue": self.log_queue,
            },
            daemon=True,
        )

        proc.start()

        # Bind pid/type as lambda defaults so the lazy log message captures
        # this iteration's values rather than the loop's final ones.
        self.debug(
            lambda pid=proc.pid,
            svc=service_type: f"Service {svc} started as process (pid: {pid})"
        )

        self.multi_process_info.append(
            MultiProcessRunInfo(
                process=proc,
                service_type=service_type,
                service_id=replica_id,
            )
        )

shutdown_all_services() async

Stop all required services as multiprocessing processes.

Source code in aiperf/controller/multiprocess_service_manager.py
114
115
116
117
118
119
120
121
122
async def shutdown_all_services(self) -> list[BaseException | None]:
    """Stop all required services as multiprocessing processes."""
    self.debug("Stopping all service processes")

    # Wait for all to finish in parallel
    return await asyncio.gather(
        *[self._wait_for_process(info) for info in self.multi_process_info],
        return_exceptions=True,
    )

wait_for_all_services_registration(stop_event, timeout_seconds=DEFAULT_SERVICE_REGISTRATION_TIMEOUT) async

Wait for all required services to be registered.

Parameters:

Name Type Description Default
stop_event Event

Event to check if operation should be cancelled

required
timeout_seconds float

Maximum time to wait in seconds

DEFAULT_SERVICE_REGISTRATION_TIMEOUT
Source code in aiperf/controller/multiprocess_service_manager.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
async def wait_for_all_services_registration(
    self,
    stop_event: asyncio.Event,
    timeout_seconds: float = DEFAULT_SERVICE_REGISTRATION_TIMEOUT,
) -> None:
    """Block until every required service type has registered.

    Args:
        stop_event: Event to check if operation should be cancelled
        timeout_seconds: Maximum time to wait in seconds

    Raises:
        AIPerfError: if required services are still unregistered when
            ``timeout_seconds`` elapses.
    """
    self.debug("Waiting for all required services to register...")

    required_types = set(self.required_services.keys())

    def _registered_types() -> set:
        # Service types whose registration has completed, per the id map.
        return {
            info.service_type
            for info in self.service_id_map.values()
            if info.registration_status == ServiceRegistrationStatus.REGISTERED
        }

    # TODO: Can this be done better by using asyncio.Event()?
    async def _poll_until_registered():
        while not stop_event.is_set():
            if required_types <= _registered_types():
                return
            # Not everything is registered yet; poll again shortly.
            await asyncio.sleep(0.5)

    try:
        await asyncio.wait_for(_poll_until_registered(), timeout=timeout_seconds)
    except asyncio.TimeoutError as e:
        # Name each service type that missed the deadline before failing.
        for service_type in required_types - _registered_types():
            self.error(
                f"Service {service_type} failed to register within timeout"
            )
        raise AIPerfError("Some services failed to register within timeout") from e

wait_for_all_services_start(stop_event, timeout_seconds=DEFAULT_SERVICE_START_TIMEOUT) async

Wait for all required services to be started.

Source code in aiperf/controller/multiprocess_service_manager.py
215
216
217
218
219
220
221
222
223
224
async def wait_for_all_services_start(
    self,
    stop_event: asyncio.Event,
    timeout_seconds: float = DEFAULT_SERVICE_START_TIMEOUT,
) -> None:
    """Wait for all required services to start (multiprocessing: no-op).

    There is no start-tracking for the multiprocessing backend yet, so this
    only logs a warning and returns immediately.
    """
    self.debug("Waiting for all required services to start...")
    not_implemented_msg = (
        "Waiting for all required services to start is not implemented for multiprocessing"
    )
    self.warning(not_implemented_msg)

aiperf.controller.proxy_manager

aiperf.controller.system_controller

SystemController

Bases: SignalHandlerMixin, BaseService

System Controller service.

This service is responsible for managing the lifecycle of all other services. It will start, stop, and configure all other services.

Source code in aiperf/controller/system_controller.py
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
@ServiceFactory.register(ServiceType.SYSTEM_CONTROLLER)
class SystemController(SignalHandlerMixin, BaseService):
    """System Controller service.

    This service is responsible for managing the lifecycle of all other services.
    It will start, stop, and configure all other services.
    """

    def __init__(
        self,
        user_config: UserConfig,
        service_config: ServiceConfig,
        service_id: str | None = None,
    ) -> None:
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
        )
        self.debug("Creating System Controller")
        self._was_cancelled = False
        # List of required service types, in no particular order
        # These are services that must be running before the system controller can start profiling
        self.required_services: dict[ServiceTypeT, int] = {
            ServiceType.DATASET_MANAGER: 1,
            ServiceType.TIMING_MANAGER: 1,
            ServiceType.WORKER_MANAGER: 1,
            ServiceType.RECORDS_MANAGER: 1,
        }
        if self.service_config.record_processor_service_count is not None:
            # An explicit count was configured: pin it, don't scale with workers.
            self.required_services[ServiceType.RECORD_PROCESSOR] = (
                self.service_config.record_processor_service_count
            )
            self.scale_record_processors_with_workers = False
        else:
            self.scale_record_processors_with_workers = True

        self.proxy_manager: ProxyManager = ProxyManager(
            service_config=self.service_config
        )
        self.service_manager: ServiceManagerProtocol = (
            ServiceManagerFactory.create_instance(
                self.service_config.service_run_type.value,
                required_services=self.required_services,
                user_config=self.user_config,
                service_config=self.service_config,
                log_queue=get_global_log_queue(),
            )
        )
        self._stop_tasks: set[asyncio.Task] = set()
        self.debug("System Controller created")

    async def initialize(self) -> None:
        """We need to override the initialize method to run the proxy manager before the base service initialize.
        This is because the proxies need to be running before we can subscribe to the message bus.
        """
        self.debug("Running ZMQ Proxy Manager Before Initialize")
        await self.proxy_manager.initialize_and_start()
        # Once the proxies are running, call the original initialize method
        await super().initialize()

    @on_init
    async def _initialize_system_controller(self) -> None:
        self.debug("Initializing System Controller")

        self.setup_signal_handlers(self._handle_signal)
        self.debug("Setup signal handlers")
        await self.service_manager.initialize()

    @on_start
    async def _start_services(self) -> None:
        """Bootstrap the system services.

        This method will:
        - Initialize all required services
        - Wait for all required services to be registered
        - Start all required services
        """
        self.debug("System Controller is bootstrapping services")
        # Start all required services
        await self.service_manager.start()
        await self.service_manager.wait_for_all_services_registration(
            stop_event=self._stop_requested_event,
        )

        self.info("AIPerf System is CONFIGURING")
        await self._profile_configure_all_services()
        self.info("AIPerf System is CONFIGURED")
        await self._start_profiling_all_services()
        self.info("AIPerf System is PROFILING")

    async def _profile_configure_all_services(self) -> None:
        """Configure all services to start profiling.

        This is a blocking call that will wait for all services to be configured before returning. This way
        we can ensure that all services are configured before we start profiling.
        """
        self.info("Configuring all services to start profiling")
        begin = time.perf_counter()
        await self.send_command_and_wait_for_all_responses(
            ProfileConfigureCommand(
                service_id=self.service_id,
                config=self.user_config,
            ),
            list(self.service_manager.service_id_map.keys()),
            timeout=DEFAULT_PROFILE_CONFIGURE_TIMEOUT,
        )
        duration = time.perf_counter() - begin
        self.info(f"All services configured in {duration:.2f} seconds")

    async def _start_profiling_all_services(self) -> None:
        """Tell all services to start profiling."""
        self.debug("Sending PROFILE_START command to all services")
        await self.send_command_and_wait_for_all_responses(
            ProfileStartCommand(
                service_id=self.service_id,
            ),
            list(self.service_manager.service_id_map.keys()),
            timeout=DEFAULT_PROFILE_START_TIMEOUT,
        )
        self.info("All services started profiling successfully")

    @on_command(CommandType.REGISTER_SERVICE)
    async def _handle_register_service_command(
        self, message: RegisterServiceCommand
    ) -> None:
        """Process a registration message from a service. It will
        add the service to the service manager and send a configure command
        to the service.

        Args:
            message: The registration message to process
        """

        self.debug(
            lambda: f"Processing registration from {message.service_type} with ID: {message.service_id}"
        )

        service_info = ServiceRunInfo(
            registration_status=ServiceRegistrationStatus.REGISTERED,
            service_type=message.service_type,
            service_id=message.service_id,
            first_seen=time.time_ns(),
            state=message.state,
            last_seen=time.time_ns(),
        )

        self.service_manager.service_id_map[message.service_id] = service_info
        # setdefault replaces the previous contains-check + manual insert with
        # a single lookup.
        self.service_manager.service_map.setdefault(message.service_type, []).append(
            service_info
        )

        try:
            type_name = ServiceType(message.service_type).name.title().replace("_", " ")
        except (TypeError, ValueError):
            type_name = message.service_type
        self.info(lambda: f"Registered {type_name} (id: '{message.service_id}')")

    @on_message(MessageType.HEARTBEAT)
    async def _process_heartbeat_message(self, message: HeartbeatMessage) -> None:
        """Process a heartbeat message from a service. It will
        update the last seen timestamp and state of the service.

        Args:
            message: The heartbeat message to process
        """
        service_id = message.service_id
        service_type = message.service_type
        timestamp = message.request_ns

        self.debug(lambda: f"Received heartbeat from {service_type} (ID: {service_id})")

        # Narrowed from a broad `except Exception`: only a missing service id
        # means "unknown service"; any other failure should surface, not be
        # mislabeled here.
        try:
            service_info = self.service_manager.service_id_map[service_id]
        except KeyError:
            self.warning(
                f"Received heartbeat from unknown service: {service_id} ({service_type})"
            )
            return

        service_info.last_seen = timestamp
        service_info.state = message.state
        self.debug(f"Updated heartbeat for {service_id} to {timestamp}")

    @on_message(MessageType.CREDITS_COMPLETE)
    async def _process_credits_complete_message(
        self, message: CreditsCompleteMessage
    ) -> None:
        """Process a credits complete message from a service. It will
        update the state of the service with the service manager.

        Args:
            message: The credits complete message to process
        """
        service_id = message.service_id
        self.info(f"Received credits complete from {service_id}")

    @on_message(MessageType.STATUS)
    async def _process_status_message(self, message: StatusMessage) -> None:
        """Process a status message from a service. It will
        update the state of the service with the service manager.

        Args:
            message: The status message to process
        """
        service_id = message.service_id
        service_type = message.service_type
        state = message.state

        self.debug(
            lambda: f"Received status update from {service_type} (ID: {service_id}): {state}"
        )

        # Single .get replaces the previous membership test followed by a
        # second lookup; an unknown sender is logged and ignored.
        service_info = self.service_manager.service_id_map.get(service_id)
        if service_info is None:
            self.debug(
                lambda: f"Received status update from un-registered service: {service_id} ({service_type})"
            )
            return

        service_info.state = message.state

        self.debug(f"Updated state for {service_id} to {message.state}")

    @on_message(MessageType.NOTIFICATION)
    async def _process_notification_message(self, message: NotificationMessage) -> None:
        """Process a notification message."""
        self.info(f"Received notification message: {message}")

    @on_message(MessageType.COMMAND_RESPONSE)
    async def _process_command_response_message(self, message: CommandResponse) -> None:
        """Process a command response message."""
        self.debug(lambda: f"Received command response message: {message}")
        if message.status == CommandResponseStatus.SUCCESS:
            self.debug(f"Command {message.command} succeeded from {message.service_id}")
        elif message.status == CommandResponseStatus.ACKNOWLEDGED:
            self.debug(
                f"Command {message.command} acknowledged from {message.service_id}"
            )
        elif message.status == CommandResponseStatus.UNHANDLED:
            self.debug(f"Command {message.command} unhandled from {message.service_id}")
        elif message.status == CommandResponseStatus.FAILURE:
            # Renamed from re-binding `message` so the parameter isn't shadowed.
            error_response = cast(CommandErrorResponse, message)
            self.error(
                f"Command {error_response.command} failed from {error_response.service_id}: {error_response.error}"
            )

    @on_command(CommandType.SPAWN_WORKERS)
    async def _handle_spawn_workers_command(self, message: SpawnWorkersCommand) -> None:
        """Handle a spawn workers command."""
        self.debug(lambda: f"Received spawn workers command: {message}")
        # Spawn the workers
        await self.service_manager.run_service(ServiceType.WORKER, message.num_workers)
        # If we are scaling the record processor service count with the number of workers, spawn the record processors
        if self.scale_record_processors_with_workers:
            await self.service_manager.run_service(
                ServiceType.RECORD_PROCESSOR, message.num_workers
            )

    @on_command(CommandType.SHUTDOWN_WORKERS)
    async def _handle_shutdown_workers_command(
        self, message: ShutdownWorkersCommand
    ) -> None:
        """Handle a shutdown workers command."""
        self.debug(lambda: f"Received shutdown workers command: {message}")
        # TODO: Handle individual worker shutdowns via worker id
        await self.service_manager.stop_service(ServiceType.WORKER)
        if self.scale_record_processors_with_workers:
            await self.service_manager.stop_service(ServiceType.RECORD_PROCESSOR)

    @on_message(MessageType.PROCESS_RECORDS_RESULT)
    async def _on_process_records_result_message(
        self, message: ProcessRecordsResultMessage
    ) -> None:
        """Handle a profile results message."""
        self.debug(lambda: f"Received profile results message: {message}")
        if message.results.errors:
            self.error(
                f"Received process records result message with errors: {message.results.errors}"
            )

        # This will be displayed by the console error exporter
        self.debug(lambda: f"Error summary: {message.results.results.error_summary}")

        if message.results.results:
            await ExporterManager(
                results=message.results.results,
                input_config=self.user_config,
            ).export_all()
        else:
            self.error(
                f"Received process records result message with no records: {message.results.results}"
            )

        if self._was_cancelled:
            warn_cancelled_early()

        # TODO: HACK: Stop the system controller after exporting the records
        self.debug("Stopping system controller after exporting records")
        await asyncio.shield(self.stop())

    async def _handle_signal(self, sig: int) -> None:
        """Handle received signals by triggering graceful shutdown.

        Args:
            sig: The signal number received
        """
        if self.stop_requested:
            # If we are already in a stopping state, we need to kill the process to be safe.
            self.warning(lambda: f"Received signal {sig}, killing")
            await self._kill()
            return

        self.debug(lambda: f"Received signal {sig}, initiating graceful shutdown")
        await self._cancel_profiling()

    async def _cancel_profiling(self) -> None:
        """Broadcast a cancel and then stop the controller."""
        self.debug("Cancelling profiling of all services")
        self._was_cancelled = True
        await self.publish(ProfileCancelCommand(service_id=self.service_id))

        # TODO: HACK: Wait for 2 seconds to ensure the profiling is cancelled
        # Wait for the profiling to be cancelled
        await asyncio.sleep(2)
        self.debug("Stopping system controller after profiling cancelled")
        await asyncio.shield(self.stop())

    @on_stop
    async def _stop_system_controller(self) -> None:
        """Stop the system controller and all running services."""
        # Broadcast a shutdown command to all services
        await self.publish(ShutdownCommand(service_id=self.service_id))

        # TODO: HACK: Wait for 0.5 seconds to ensure the shutdown command is received
        await asyncio.sleep(0.5)

        await self.service_manager.shutdown_all_services()
        await self.comms.stop()
        await self.proxy_manager.stop()

    async def _kill(self):
        """Kill the system controller."""
        try:
            await self.service_manager.kill_all_services()
        except Exception as e:
            raise self._service_error("Failed to stop all services") from e

        await super()._kill()

initialize() async

We need to override the initialize method to run the proxy manager before the base service initialize. This is because the proxies need to be running before we can subscribe to the message bus.

Source code in aiperf/controller/system_controller.py
105
106
107
108
109
110
111
112
async def initialize(self) -> None:
    """We need to override the initialize method to run the proxy manager before the base service initialize.
    This is because the proxies need to be running before we can subscribe to the message bus.
    """
    self.debug("Running ZMQ Proxy Manager Before Initialize")
    # Start the ZMQ proxies first: message-bus subscriptions made during the
    # base-class initialize would have nothing to connect to otherwise.
    await self.proxy_manager.initialize_and_start()
    # Once the proxies are running, call the original initialize method
    await super().initialize()

main()

Main entry point for the system controller.

Source code in aiperf/controller/system_controller.py
406
407
408
409
410
411
def main() -> None:
    """CLI entry point: bootstrap and run the SystemController service."""
    # Imported lazily so that importing this module stays side-effect free.
    from aiperf.common.bootstrap import bootstrap_and_run_service as _bootstrap

    _bootstrap(SystemController)

aiperf.controller.system_mixins

SignalHandlerMixin

Bases: AIPerfLoggerMixin

Mixin for services that need to handle system signals.

Source code in aiperf/controller/system_mixins.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class SignalHandlerMixin(AIPerfLoggerMixin):
    """Mixin for services that need to handle system signals."""

    def __init__(self, **kwargs) -> None:
        # Set to store signal handler tasks to prevent them from being garbage collected
        self._signal_tasks = set()
        super().__init__(**kwargs)

    def setup_signal_handlers(self, callback: Callable[[int], Coroutine]) -> None:
        """This method will set up signal handlers for the SIGTERM and SIGINT signals
        in order to trigger a graceful shutdown of the service.

        Args:
            callback: The callback to call when a signal is received
        """
        loop = asyncio.get_running_loop()

        def signal_handler(sig: int) -> None:
            # Schedule the async callback and keep a strong reference to the
            # task so it is not garbage-collected before it completes.
            task = asyncio.create_task(callback(sig))
            self._signal_tasks.add(task)
            task.add_done_callback(self._signal_tasks.discard)

        # BUG FIX: SIGTERM was documented above but never registered, so
        # `kill <pid>` bypassed graceful shutdown. Register both signals.
        for sig in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(sig, signal_handler, sig)

setup_signal_handlers(callback)

This method will set up signal handlers for the SIGTERM and SIGINT signals in order to trigger a graceful shutdown of the service.

Parameters:

Name Type Description Default
callback Callable[[int], Coroutine]

The callback to call when a signal is received

required
Source code in aiperf/controller/system_mixins.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def setup_signal_handlers(self, callback: Callable[[int], Coroutine]) -> None:
    """This method will set up signal handlers for the SIGTERM and SIGINT signals
    in order to trigger a graceful shutdown of the service.

    Args:
        callback: The callback to call when a signal is received
    """
    loop = asyncio.get_running_loop()

    def signal_handler(sig: int) -> None:
        # Schedule the async callback and keep a strong reference to the
        # task so it is not garbage-collected before it completes.
        task = asyncio.create_task(callback(sig))
        self._signal_tasks.add(task)
        task.add_done_callback(self._signal_tasks.discard)

    # BUG FIX: SIGTERM was documented above but never registered, so
    # `kill <pid>` bypassed graceful shutdown. Register both signals.
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, signal_handler, sig)

aiperf.dataset.composer.base

BaseDatasetComposer

Bases: AIPerfLoggerMixin, ABC

Source code in aiperf/dataset/composer/base.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class BaseDatasetComposer(AIPerfLoggerMixin, ABC):
    """Shared plumbing for dataset composers: payload generators plus
    model-name selection according to the configured strategy."""

    def __init__(self, config: UserConfig, tokenizer: Tokenizer, **kwargs):
        self.config = config
        super().__init__(config=config, tokenizer=tokenizer, **kwargs)
        # One generator per supported modality.
        self.prompt_generator = PromptGenerator(config.input.prompt, tokenizer)
        self.image_generator = ImageGenerator(config.input.image)
        self.audio_generator = AudioGenerator(config.input.audio)
        # Monotonic counter driving round-robin model selection.
        self.turn_count = 0

    @abstractmethod
    def create_dataset(self) -> list[Conversation]:
        """
        Create a set of conversation objects from the given configuration.

        Returns:
            list[Conversation]: A list of conversation objects.
        """
        ...

    def _select_model_name(self) -> str:
        """Pick a model name using the configured selection strategy."""
        strategy = self.config.endpoint.model_selection_strategy
        names = self.config.endpoint.model_names
        if strategy == ModelSelectionStrategy.RANDOM:
            return random.choice(names)
        if strategy == ModelSelectionStrategy.ROUND_ROBIN:
            chosen = names[self.turn_count % len(names)]
            self.turn_count += 1
            return chosen
        raise ValueError(
            f"Invalid model selection strategy: {self.config.endpoint.model_selection_strategy}."
        )

    @property
    def prefix_prompt_enabled(self) -> bool:
        """True when a non-empty prefix prompt is configured."""
        return self.config.input.prompt.prefix_prompt.length > 0

create_dataset() abstractmethod

Create a set of conversation objects from the given configuration.

Returns:

Type Description
list[Conversation]

list[Conversation]: A list of conversation objects.

Source code in aiperf/dataset/composer/base.py
28
29
30
31
32
33
34
35
36
@abstractmethod
def create_dataset(self) -> list[Conversation]:
    """
    Create a set of conversation objects from the given configuration.

    Subclasses implement the concrete composition strategy
    (e.g. synthetic generation vs. loading from a custom file).

    Returns:
        list[Conversation]: A list of conversation objects.
    """
    ...

aiperf.dataset.composer.custom

CustomDatasetComposer

Bases: BaseDatasetComposer

Source code in aiperf/dataset/composer/custom.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# NOTE(review): implements_protocol(ServiceProtocol) looks copy-pasted from a
# service class — a dataset composer is not a service; confirm this is intended.
@implements_protocol(ServiceProtocol)
@ComposerFactory.register(ComposerType.CUSTOM)
class CustomDatasetComposer(BaseDatasetComposer):
    """Compose conversations from a user-supplied dataset file."""

    def __init__(self, config: UserConfig, tokenizer: Tokenizer):
        super().__init__(config, tokenizer)

    def create_dataset(self) -> list[Conversation]:
        """Create conversations from a file or directory.

        Returns:
            list[Conversation]: A list of conversation objects.
        """
        # TODO: (future) for K8s, we need to transfer file data from SC (across node)
        utils.check_file_exists(self.config.input.file)

        self._create_loader_instance(self.config.input.custom_dataset_type)
        dataset = self.loader.load_dataset()
        conversations = self.loader.convert_to_conversations(dataset)
        self._add_model_names_to_conversations(conversations)
        return conversations

    def _create_loader_instance(self, dataset_type: CustomDatasetType) -> None:
        """Initializes the dataset loader based on the custom dataset type.

        Args:
            dataset_type: The type of custom dataset to create.
        """
        # Each loader type takes a different set of constructor kwargs.
        kwargs = {"filename": self.config.input.file}
        if dataset_type == CustomDatasetType.MOONCAKE_TRACE:
            kwargs["prompt_generator"] = self.prompt_generator
        elif dataset_type == CustomDatasetType.RANDOM_POOL:
            kwargs["num_conversations"] = self.config.input.conversation.num

        self.loader = CustomDatasetFactory.create_instance(dataset_type, **kwargs)

    def _add_model_names_to_conversations(
        self, conversations: list[Conversation]
    ) -> None:
        """Assign a model name to every turn using the configured strategy."""
        for conversation in conversations:
            for turn in conversation.turns:
                turn.model = self._select_model_name()

create_dataset()

Create conversations from a file or directory.

Returns:

Type Description
list[Conversation]

list[Conversation]: A list of conversation objects.

Source code in aiperf/dataset/composer/custom.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def create_dataset(self) -> list[Conversation]:
    """Create conversations from a file or directory.

    Returns:
        list[Conversation]: A list of conversation objects.
    """
    # TODO: (future) for K8s, we need to transfer file data from SC (across node)
    utils.check_file_exists(self.config.input.file)

    # Loader is chosen per configured dataset type, then used to load and
    # convert the raw records into Conversation objects.
    self._create_loader_instance(self.config.input.custom_dataset_type)
    dataset = self.loader.load_dataset()
    conversations = self.loader.convert_to_conversations(dataset)
    self._add_model_names_to_conversations(conversations)
    return conversations

aiperf.dataset.composer.synthetic

SyntheticDatasetComposer

Bases: BaseDatasetComposer

Source code in aiperf/dataset/composer/synthetic.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
@ComposerFactory.register(ComposerType.SYNTHETIC)
class SyntheticDatasetComposer(BaseDatasetComposer):
    """Compose fully synthetic conversations (text, image, audio payloads)."""

    def __init__(self, config: UserConfig, tokenizer: Tokenizer):
        super().__init__(config, tokenizer)

        # Fail fast: at least one modality must have a positive mean,
        # otherwise every generated turn would be empty.
        if (
            not self.include_prompt
            and not self.include_image
            and not self.include_audio
        ):
            raise ValueError(
                "All synthetic data are disabled. "
                "Please enable at least one of prompt, image, or audio by "
                "setting the mean to a positive value."
            )

    def create_dataset(self) -> list[Conversation]:
        """Create a synthetic conversation dataset from the given configuration.

        It generates a set of conversations with a varying number of turns,
        where each turn contains synthetic text, image, and audio payloads.

        Returns:
            list[Conversation]: A list of conversation objects.
        """
        conversations = []
        for _ in range(self.config.input.conversation.num):
            conversation = Conversation(session_id=str(uuid.uuid4()))

            # Turn count is drawn from a normal distribution truncated to
            # positive integers.
            num_turns = utils.sample_positive_normal_integer(
                self.config.input.conversation.turn.mean,
                self.config.input.conversation.turn.stddev,
            )
            self.logger.debug("Creating conversation with %d turns", num_turns)

            for turn_idx in range(num_turns):
                turn = self._create_turn(is_first=(turn_idx == 0))
                conversation.turns.append(turn)
            conversations.append(conversation)
        return conversations

    def _create_turn(self, is_first: bool) -> Turn:
        """Create a turn object that contains synthetic payloads to send.

        It generates multi-modal data (e.g. text, image, audio) using synthetic
        generators and also the delay between turns.

        Args:
            is_first: Whether the turn is the first turn in the conversation.

        Returns:
            Turn: A dataset representation of a single turn.
        """
        turn = Turn()

        if self.include_prompt:
            turn.texts.append(self._generate_text_payloads(is_first))
        if self.include_image:
            turn.images.append(self._generate_image_payloads())
        if self.include_audio:
            turn.audios.append(self._generate_audio_payloads())

        # Add randomized delays between each turn. Skip if first turn.
        if not is_first:
            turn.delay = utils.sample_positive_normal_integer(
                self.config.input.conversation.turn.delay.mean,
                self.config.input.conversation.turn.delay.stddev,
            )

        # Defensive check: the constructor already guarantees at least one
        # modality is enabled, so this warning should not normally trigger.
        if not turn.texts and not turn.images and not turn.audios:
            self.logger.warning(
                "There were no synthetic payloads generated. "
                "Please enable at least one of prompt, image, or audio by "
                "setting the mean to a positive value."
            )

        turn.model = self._select_model_name()

        return turn

    def _generate_text_payloads(self, is_first: bool) -> Text:
        """Generate synthetic text payloads.

        If the turn is the first turn in the conversation, it could add a prefix prompt
        to the prompt.

        Args:
            is_first: Whether the turn is the first turn in the conversation.

        Returns:
            Text: A text payload object.
        """
        text = Text(name="text")
        for _ in range(self.config.input.prompt.batch_size):
            prompt = self.prompt_generator.generate(
                mean=self.config.input.prompt.input_tokens.mean,
                stddev=self.config.input.prompt.input_tokens.stddev,
            )

            # Prefix prompts are only prepended to the first turn of a
            # conversation.
            if self.prefix_prompt_enabled and is_first:
                # TODO: Rename
                prefix_prompt = self.prompt_generator.get_random_prefix_prompt()
                prompt = f"{prefix_prompt} {prompt}"

            text.contents.append(prompt)
        return text

    def _generate_image_payloads(self) -> Image:
        """
        Generate synthetic images if the image width and height are specified.

        Returns:
            Image: An image payload object.
        """
        image = Image(name="image_url")
        for _ in range(self.config.input.image.batch_size):
            data = self.image_generator.generate()
            image.contents.append(data)
        return image

    def _generate_audio_payloads(self) -> Audio:
        """
        Generate synthetic audios if the audio length is specified.

        Returns:
            Audio: An audio payload object.
        """
        audio = Audio(name="input_audio")
        for _ in range(self.config.input.audio.batch_size):
            data = self.audio_generator.generate()
            audio.contents.append(data)
        return audio

    @property
    def include_prompt(self) -> bool:
        # Enabled when the configured mean token count is positive.
        return self.config.input.prompt.input_tokens.mean > 0

    @property
    def include_image(self) -> bool:
        # Both dimensions must be positive for image generation.
        return (
            self.config.input.image.width.mean > 0
            and self.config.input.image.height.mean > 0
        )

    @property
    def include_audio(self) -> bool:
        # Enabled when the configured mean audio length is positive.
        return self.config.input.audio.length.mean > 0

create_dataset()

Create a synthetic conversation dataset from the given configuration.

It generates a set of conversations with a varying number of turns, where each turn contains synthetic text, image, and audio payloads.

Returns:

Type Description
list[Conversation]

list[Conversation]: A list of conversation objects.

Source code in aiperf/dataset/composer/synthetic.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def create_dataset(self) -> list[Conversation]:
    """Create a synthetic conversation dataset from the given configuration.

    It generates a set of conversations with a varying number of turns,
    where each turn contains synthetic text, image, and audio payloads.

    Returns:
        list[Conversation]: A list of conversation objects.
    """
    conversations = []
    for _ in range(self.config.input.conversation.num):
        conversation = Conversation(session_id=str(uuid.uuid4()))

        # Turn count drawn from a normal distribution truncated to
        # positive integers.
        num_turns = utils.sample_positive_normal_integer(
            self.config.input.conversation.turn.mean,
            self.config.input.conversation.turn.stddev,
        )
        self.logger.debug("Creating conversation with %d turns", num_turns)

        for turn_idx in range(num_turns):
            turn = self._create_turn(is_first=(turn_idx == 0))
            conversation.turns.append(turn)
        conversations.append(conversation)
    return conversations

aiperf.dataset.dataset_manager

DatasetManager

Bases: ReplyClientMixin, BaseComponentService

The DatasetManager's primary responsibility is to manage data generation or acquisition. For synthetic generation, it contains the code to generate the prompts or tokens. It will have an API to acquire a dataset if one is available in a remote repository or database.

Source code in aiperf/dataset/dataset_manager.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
@implements_protocol(ServiceProtocol)
@ServiceFactory.register(ServiceType.DATASET_MANAGER)
class DatasetManager(ReplyClientMixin, BaseComponentService):
    """
    The DatasetManager's primary responsibility is to manage data generation or acquisition.
    For synthetic generation, it contains the code to generate the prompts or tokens.
    It will have an API to acquire a dataset if available in a remote repository or database.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str | None = None,
    ) -> None:
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
            reply_client_address=CommAddress.DATASET_MANAGER_PROXY_BACKEND,
            reply_client_bind=False,
        )
        self.debug("Dataset manager __init__")
        self.user_config = user_config
        self.tokenizer: Tokenizer | None = None
        self.dataset: dict[str, Conversation] = {}  # session ID -> Conversation mapping
        # Cached key list so random selection doesn't rebuild it per request.
        self._session_ids_cache: list[str] = []
        # Dedicated RNG seeded from the config so conversation selection is
        # reproducible across runs with the same seed.
        self._conversation_query_random = random.Random(
            self.user_config.input.random_seed
        )
        # Set once the dataset has been built; request handlers wait on it.
        self.dataset_configured = asyncio.Event()

    @on_command(CommandType.PROFILE_CONFIGURE)
    async def _profile_configure_command(
        self, message: ProfileConfigureCommand
    ) -> None:
        """Configure the dataset."""
        self.info(lambda: f"Configuring dataset for {self.service_id}")
        begin = time.perf_counter()
        await self._configure_dataset()
        duration = time.perf_counter() - begin
        self.info(lambda: f"Dataset configured in {duration:.2f} seconds")

    async def _configure_dataset(self) -> None:
        """Build the dataset (custom file or synthetic) and publish a notification."""
        if self.user_config is None:
            raise self._service_error("User config is required for dataset manager")

        # Clear first so concurrent request handlers block until rebuild completes.
        self.dataset_configured.clear()
        if self.user_config.input.file:
            composer_type = ComposerType.CUSTOM
            self.debug(
                lambda: f"Detected input file '{self.user_config.input.file}'. Setting the composer type to {ComposerType.CUSTOM}."
            )
        else:
            composer_type = ComposerType.SYNTHETIC
            self.debug(
                lambda: f"No input file detected. Setting the composer type to {ComposerType.SYNTHETIC}."
            )

        tokenizer_name = self.user_config.tokenizer.name
        if tokenizer_name is None:
            # TODO: What do we do if there are multiple models?
            # How will we know which tokenizer to use?
            tokenizer_name = self.user_config.endpoint.model_names[0]

        tokenizer = Tokenizer.from_pretrained(
            tokenizer_name,
            trust_remote_code=self.user_config.tokenizer.trust_remote_code,
            revision=self.user_config.tokenizer.revision,
        )
        composer = ComposerFactory.create_instance(
            composer_type,
            config=self.user_config,
            tokenizer=tokenizer,
        )
        conversations = composer.create_dataset()
        self.dataset = {conv.session_id: conv for conv in conversations}
        self._session_ids_cache = list(self.dataset.keys())

        self.dataset_configured.set()
        await self.publish(
            DatasetConfiguredNotification(
                service_id=self.service_id,
            ),
        )

    @on_request(MessageType.CONVERSATION_REQUEST)
    async def _handle_conversation_request(
        self, message: ConversationRequestMessage
    ) -> ConversationResponseMessage:
        """Handle a conversation request."""
        self.debug(lambda: f"Handling conversation request: {message}")

        await self._wait_for_dataset_configuration()

        if not self.dataset:
            raise self._service_error(
                "Dataset is empty and must be configured before handling requests.",
            )

        # No id means "any conversation"; otherwise look up the exact one.
        if message.conversation_id is None:
            return self._return_any_conversation(
                request_id=message.request_id,
            )
        else:
            return self._return_conversation_by_id(
                request_id=message.request_id,
                conversation_id=message.conversation_id,
            )

    def _return_any_conversation(
        self, request_id: str | None
    ) -> ConversationResponseMessage:
        """Return any conversation from the dataset based on the user specified method."""

        # TODO: Implement the user specified method (random, round robin, etc.)
        session_id = self._conversation_query_random.choice(self._session_ids_cache)
        conversation = self.dataset[session_id]
        self.trace_or_debug(
            lambda: f"Sending random conversation response: {conversation}",
            lambda: f"Sending random conversation response with id: {conversation.session_id}",
        )
        return ConversationResponseMessage(
            service_id=self.service_id,
            request_id=request_id,
            conversation=conversation,
        )

    def _return_conversation_by_id(
        self, request_id: str | None, conversation_id: str
    ) -> ConversationResponseMessage:
        """Return a conversation if it exists, otherwise raise an error."""

        if conversation_id not in self.dataset:
            raise self._service_error(
                f"Conversation {conversation_id} not found in dataset.",
            )

        conversation = self.dataset[conversation_id]
        self.trace_or_debug(
            lambda: f"Sending conversation response: {conversation}",
            lambda: f"Sending conversation response with id: {conversation.session_id}",
        )
        return ConversationResponseMessage(
            service_id=self.service_id,
            request_id=request_id,
            conversation=conversation,
        )

    @on_request(MessageType.CONVERSATION_TURN_REQUEST)
    async def _handle_conversation_turn_request(
        self, message: ConversationTurnRequestMessage
    ) -> ConversationTurnResponseMessage:
        """Handle a turn request."""
        self.debug(lambda: f"Handling turn request: {message}")

        # NOTE(review): unlike the conversation and timing handlers, this does
        # not await _wait_for_dataset_configuration(); a turn request arriving
        # before configuration completes will fail — confirm this is intended.
        if message.conversation_id not in self.dataset:
            raise self._service_error(
                f"Conversation {message.conversation_id} not found in dataset.",
            )

        conversation = self.dataset[message.conversation_id]
        if message.turn_index >= len(conversation.turns):
            raise self._service_error(
                f"Turn index {message.turn_index} is out of range for conversation {message.conversation_id}.",
            )

        turn = conversation.turns[message.turn_index]

        self.trace_or_debug(
            lambda: f"Sending turn response: {turn}",
            "Sending turn response",
        )
        return ConversationTurnResponseMessage(
            service_id=self.service_id,
            request_id=message.request_id,
            turn=turn,
        )

    @on_request(MessageType.DATASET_TIMING_REQUEST)
    async def _handle_dataset_timing_request(
        self, message: DatasetTimingRequest
    ) -> DatasetTimingResponse:
        """Handle a dataset timing request."""
        self.trace_or_debug(
            lambda: f"Handling dataset timing request: {message}",
            "Handling dataset timing request",
        )

        await self._wait_for_dataset_configuration()

        if not self.dataset:
            raise self._service_error(
                "Dataset is empty and must be configured before handling timing requests.",
            )

        # Flatten every turn into (timestamp, conversation_id) pairs.
        timing_dataset = []
        for conversation_id, conversation in self.dataset.items():
            for turn in conversation.turns:
                timing_dataset.append((turn.timestamp, conversation_id))

        return DatasetTimingResponse(
            service_id=self.service_id,
            request_id=message.request_id,
            timing_data=timing_dataset,
        )

    async def _wait_for_dataset_configuration(self) -> None:
        """Wait for the dataset to be configured if it is not already."""
        if not self.dataset_configured.is_set():
            self.debug(
                "Dataset not configured. Waiting for dataset to be configured..."
            )
            await asyncio.wait_for(
                self.dataset_configured.wait(), timeout=DATASET_CONFIGURATION_TIMEOUT
            )

main()

Main entry point for the dataset manager.

Source code in aiperf/dataset/dataset_manager.py
250
251
252
253
254
255
def main() -> None:
    """CLI entry point: bootstrap and run the DatasetManager service."""
    # Imported lazily so that importing this module stays side-effect free.
    from aiperf.common.bootstrap import bootstrap_and_run_service as _bootstrap

    _bootstrap(DatasetManager)

aiperf.dataset.generator.audio

AudioGenerator

Bases: BaseGenerator

A class for generating synthetic audio data.

This class provides methods to create audio samples with specified characteristics such as format (WAV, MP3), length, sampling rate, bit depth, and number of channels. It supports validation of audio parameters to ensure compatibility with chosen formats.

Source code in aiperf/dataset/generator/audio.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
class AudioGenerator(BaseGenerator):
    """
    A class for generating synthetic audio data.

    This class provides methods to create audio samples with specified
    characteristics such as format (WAV, MP3), length, sampling rate,
    bit depth, and number of channels. It supports validation of audio
    parameters to ensure compatibility with chosen formats.
    """

    def __init__(self, config: AudioConfig):
        super().__init__()
        self.config = config

    def _validate_sampling_rate(
        self, sampling_rate_hz: int, audio_format: AudioFormat
    ) -> None:
        """
        Validate sampling rate for the given output format.

        Args:
            sampling_rate_hz: Sampling rate in Hz
            audio_format: Audio format

        Raises:
            ConfigurationError: If sampling rate is not supported for the given format
        """
        # MP3 only allows a fixed set of sample rates; WAV accepts anything.
        if (
            audio_format == AudioFormat.MP3
            and sampling_rate_hz not in MP3_SUPPORTED_SAMPLE_RATES
        ):
            supported_rates = sorted(MP3_SUPPORTED_SAMPLE_RATES)
            raise ConfigurationError(
                f"MP3 format only supports the following sample rates (in Hz): {supported_rates}. "
                f"Got {sampling_rate_hz} Hz. Please choose a supported rate from the list."
            )

    def _validate_bit_depth(self, bit_depth: int) -> None:
        """
        Validate bit depth is supported.

        Args:
            bit_depth: Bit depth in bits

        Raises:
            ConfigurationError: If bit depth is not supported
        """
        if bit_depth not in SUPPORTED_BIT_DEPTHS:
            supported_depths = sorted(SUPPORTED_BIT_DEPTHS.keys())
            raise ConfigurationError(
                f"Unsupported bit depth: {bit_depth}. "
                f"Supported bit depths are: {supported_depths}"
            )

    def generate(self, *args, **kwargs) -> str:
        """Generate audio data with specified parameters.

        Returns:
            Data URI containing base64-encoded audio data with format specification

        Raises:
            ConfigurationError: If any of the following conditions are met:
                - audio length is less than 0.01 seconds
                - channels is not 1 (mono) or 2 (stereo)
                - sampling rate is not supported for MP3 format
                - bit depth is not supported (must be 8, 16, 24, or 32)
                - audio format is not supported (must be 'wav' or 'mp3')
        """
        if self.config.num_channels not in (1, 2):
            raise ConfigurationError(
                "Only mono (1) and stereo (2) channels are supported"
            )

        if self.config.length.mean < 0.01:
            raise ConfigurationError("Audio length must be greater than 0.01 seconds")

        # Sample audio length (in seconds) using rejection sampling
        audio_length = utils.sample_normal(
            self.config.length.mean, self.config.length.stddev, lower=0.01
        )

        # Randomly select sampling rate and bit depth
        # NOTE(review): sample_rates are assumed to be in kHz — confirm config units.
        sampling_rate_hz = int(
            np.random.choice(self.config.sample_rates) * 1000
        )  # Convert kHz to Hz
        bit_depth = np.random.choice(self.config.depths)

        # Validate sampling rate and bit depth
        self._validate_sampling_rate(sampling_rate_hz, self.config.format)
        self._validate_bit_depth(bit_depth)

        # Generate synthetic audio data (gaussian noise)
        num_samples = int(audio_length * sampling_rate_hz)
        audio_data = np.random.normal(
            0,
            0.3,
            (
                (num_samples, self.config.num_channels)
                if self.config.num_channels > 1
                else num_samples
            ),
        )

        # Ensure the signal is within [-1, 1] range
        audio_data = np.clip(audio_data, -1, 1)

        # Scale to the appropriate bit depth range
        max_val = 2 ** (bit_depth - 1) - 1
        numpy_type, _ = SUPPORTED_BIT_DEPTHS[bit_depth]
        audio_data = (audio_data * max_val).astype(numpy_type)

        # Write audio using soundfile
        output_buffer = io.BytesIO()

        # Select appropriate subtype based on format
        if self.config.format == AudioFormat.MP3:
            subtype = "MPEG_LAYER_III"
        elif self.config.format == AudioFormat.WAV:
            _, subtype = SUPPORTED_BIT_DEPTHS[bit_depth]
        else:
            raise ConfigurationError(
                f"Unsupported audio format: {self.config.format}. "
                f"Supported formats are: {AudioFormat.WAV.name}, {AudioFormat.MP3.name}"
            )

        # NOTE(review): format/lower() usage implies AudioFormat is a str-based
        # enum — confirm, since a plain Enum would not support .lower().
        sf.write(
            output_buffer,
            audio_data,
            sampling_rate_hz,
            format=self.config.format,
            subtype=subtype,
        )
        audio_bytes = output_buffer.getvalue()

        # Encode to base64 with data URI scheme: "{format},{data}"
        base64_data = base64.b64encode(audio_bytes).decode("utf-8")
        return f"{self.config.format.lower()},{base64_data}"

generate(*args, **kwargs)

Generate audio data with specified parameters.

Returns:

Type Description
str

Data URI containing base64-encoded audio data with format specification

Raises:

Type Description
ConfigurationError

If any of the following conditions are met: - audio length is less than 0.01 seconds - channels is not 1 (mono) or 2 (stereo) - sampling rate is not supported for MP3 format - bit depth is not supported (must be 8, 16, 24, or 32) - audio format is not supported (must be 'wav' or 'mp3')

Source code in aiperf/dataset/generator/audio.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def generate(self, *args, **kwargs) -> str:
    """Generate a synthetic audio clip and return it as a data URI.

    Returns:
        Data URI string of the form "{format},{base64-data}".

    Raises:
        ConfigurationError: If any of the following conditions are met:
            - audio length is less than 0.01 seconds
            - channels is not 1 (mono) or 2 (stereo)
            - sampling rate is not supported for MP3 format
            - bit depth is not supported (must be 8, 16, 24, or 32)
            - audio format is not supported (must be 'wav' or 'mp3')
    """
    if self.config.num_channels not in (1, 2):
        raise ConfigurationError(
            "Only mono (1) and stereo (2) channels are supported"
        )

    if self.config.length.mean < 0.01:
        raise ConfigurationError("Audio length must be greater than 0.01 seconds")

    # Draw the clip duration (seconds) from a truncated normal distribution.
    duration_s = utils.sample_normal(
        self.config.length.mean, self.config.length.stddev, lower=0.01
    )

    # Pick a sampling rate (configured in kHz, converted to Hz) and a bit depth.
    rate_hz = int(np.random.choice(self.config.sample_rates) * 1000)
    depth = np.random.choice(self.config.depths)

    self._validate_sampling_rate(rate_hz, self.config.format)
    self._validate_bit_depth(depth)

    # Synthesize gaussian noise, either mono (1-D) or multi-channel (2-D).
    sample_count = int(duration_s * rate_hz)
    if self.config.num_channels > 1:
        shape = (sample_count, self.config.num_channels)
    else:
        shape = sample_count
    samples = np.random.normal(0, 0.3, shape)

    # Clip to [-1, 1], then scale to the integer range of the chosen depth.
    samples = np.clip(samples, -1, 1)
    peak = 2 ** (depth - 1) - 1
    dtype, _ = SUPPORTED_BIT_DEPTHS[depth]
    samples = (samples * peak).astype(dtype)

    # MP3 uses a fixed subtype; WAV derives its subtype from the bit depth.
    if self.config.format == AudioFormat.MP3:
        subtype = "MPEG_LAYER_III"
    elif self.config.format == AudioFormat.WAV:
        _, subtype = SUPPORTED_BIT_DEPTHS[depth]
    else:
        raise ConfigurationError(
            f"Unsupported audio format: {self.config.format}. "
            f"Supported formats are: {AudioFormat.WAV.name}, {AudioFormat.MP3.name}"
        )

    buffer = io.BytesIO()
    sf.write(
        buffer,
        samples,
        rate_hz,
        format=self.config.format,
        subtype=subtype,
    )

    # Base64-encode with the "{format},{data}" URI scheme.
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"{self.config.format.lower()},{encoded}"

aiperf.dataset.generator.base

BaseGenerator

Bases: AIPerfLoggerMixin, ABC

Abstract base class for all data generators.

Provides a consistent interface for generating synthetic data while allowing each generator type to use its own specific configuration and runtime parameters.

Source code in aiperf/dataset/generator/base.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class BaseGenerator(AIPerfLoggerMixin, ABC):
    """Common interface for all synthetic data generators.

    Each generator type supplies its own configuration and runtime
    parameters; callers always receive the generated data as a string.
    """

    @abstractmethod
    def generate(self, *args, **kwargs) -> str:
        """Produce synthetic data.

        Args:
            *args: Positional parameters specific to the subclass.
            **kwargs: Keyword parameters specific to the subclass.

        Returns:
            The generated data as a string (plain text, base64-encoded
            media, etc.).
        """

generate(*args, **kwargs) abstractmethod

Generate synthetic data.

Parameters:

Name Type Description Default
*args

Variable length argument list (subclass-specific)

()
**kwargs

Arbitrary keyword arguments (subclass-specific)

{}

Returns:

Type Description
str

Generated data as a string (could be text, base64 encoded media, etc.)

Source code in aiperf/dataset/generator/base.py
16
17
18
19
20
21
22
23
24
25
26
27
@abstractmethod
def generate(self, *args, **kwargs) -> str:
    """Produce synthetic data.

    Args:
        *args: Positional parameters specific to the subclass.
        **kwargs: Keyword parameters specific to the subclass.

    Returns:
        The generated data as a string (plain text, base64-encoded media, etc.).
    """

aiperf.dataset.generator.image

ImageGenerator

Bases: BaseGenerator

A class that generates images from source images.

This class provides methods to create synthetic images by resizing source images (located in the 'assets/source_images' directory) to specified dimensions and converting them to a chosen image format (e.g., PNG, JPEG). The dimensions can be randomized based on mean and standard deviation values.

Source code in aiperf/dataset/generator/image.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class ImageGenerator(BaseGenerator):
    """A class that generates images from source images.

    Synthetic images are produced by picking a random source image from the
    'assets/source_images' directory, resizing it to randomized dimensions
    (sampled around the configured mean/stddev), and encoding it in the
    configured image format (e.g., PNG, JPEG).
    """

    def __init__(self, config: ImageConfig):
        super().__init__()
        self.config = config

    def generate(self, *args, **kwargs) -> str:
        """Generate an image with the configured parameters.

        Returns:
            A base64 encoded string of the generated image.
        """
        fmt = self.config.format
        if fmt == ImageFormat.RANDOM:
            # Resolve RANDOM into one of the concrete formats.
            concrete = [f for f in ImageFormat if f != ImageFormat.RANDOM]
            fmt = random.choice(concrete)

        target_w = utils.sample_positive_normal_integer(
            self.config.width.mean, self.config.width.stddev
        )
        target_h = utils.sample_positive_normal_integer(
            self.config.height.mean, self.config.height.stddev
        )

        self.logger.debug(
            "Generating image with width=%d, height=%d",
            target_w,
            target_h,
        )

        resized = self._sample_source_image().resize(size=(target_w, target_h))
        encoded = utils.encode_image(resized, fmt)
        return f"data:image/{fmt.name.lower()};base64,{encoded}"

    def _sample_source_image(self):
        """Pick one of the bundled source images at random.

        Returns:
            A PIL Image object randomly selected from the source images.
        """
        pattern = Path(__file__).parent.resolve() / "assets" / "source_images" / "*"
        matches = glob.glob(str(pattern))
        if not matches:
            raise ValueError(f"No source images found in '{pattern}'")
        return Image.open(random.choice(matches))

generate(*args, **kwargs)

Generate an image with the configured parameters.

Returns:

Type Description
str

A base64 encoded string of the generated image.

Source code in aiperf/dataset/generator/image.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def generate(self, *args, **kwargs) -> str:
    """Generate an image with the configured parameters.

    Returns:
        A base64 encoded string of the generated image.
    """
    fmt = self.config.format
    if fmt == ImageFormat.RANDOM:
        # Resolve RANDOM into one of the concrete formats.
        concrete = [f for f in ImageFormat if f != ImageFormat.RANDOM]
        fmt = random.choice(concrete)

    target_w = utils.sample_positive_normal_integer(
        self.config.width.mean, self.config.width.stddev
    )
    target_h = utils.sample_positive_normal_integer(
        self.config.height.mean, self.config.height.stddev
    )

    self.logger.debug(
        "Generating image with width=%d, height=%d",
        target_w,
        target_h,
    )

    resized = self._sample_source_image().resize(size=(target_w, target_h))
    encoded = utils.encode_image(resized, fmt)
    return f"data:image/{fmt.name.lower()};base64,{encoded}"

aiperf.dataset.generator.prompt

PromptGenerator

Bases: BaseGenerator

A class for generating synthetic prompts from a text corpus.

This class loads a text corpus (e.g., Shakespearean text), tokenizes it, and uses the tokenized corpus to generate synthetic prompts of specified lengths. It supports generating prompts with a target number of tokens (with optional randomization around a mean and standard deviation) and can reuse previously generated token blocks to optimize generation for certain use cases. It also allows for the creation of a pool of prefix prompts that can be randomly selected.

Source code in aiperf/dataset/generator/prompt.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
class PromptGenerator(BaseGenerator):
    """A class for generating synthetic prompts from a text corpus.

    This class loads a text corpus (e.g., Shakespearean text), tokenizes it,
    and uses the tokenized corpus to generate synthetic prompts of specified
    lengths. It supports generating prompts with a target number of tokens
    (with optional randomization around a mean and standard deviation) and
    can reuse previously generated token blocks to optimize generation for
    certain use cases. It also allows for the creation of a pool of prefix
    prompts that can be randomly selected.
    """

    def __init__(self, config: PromptConfig, tokenizer: Tokenizer, **kwargs):
        # config/tokenizer are assigned before super().__init__ and also
        # forwarded to it as kwargs — presumably consumed by the logger
        # mixin; NOTE(review): confirm the mixin's expected kwargs.
        self.config = config
        self.tokenizer = tokenizer
        # Flattened token IDs of the whole corpus; None until _initialize_corpus runs.
        self._tokenized_corpus: list[int] | None = None
        self._corpus_size: int = 0
        self._prefix_prompts: list[str] = []
        super().__init__(config=config, tokenizer=tokenizer, **kwargs)

        # Cached prompts: block ID -> list of tokens
        self._cache: dict[int, list[int]] = {}

        # TODO: move this under initialize() method
        # Initialize corpus if not already done
        if self._tokenized_corpus is None:
            self._initialize_corpus()

        # Initialize prefix prompts pool if the pool size > 0
        if self.config.prefix_prompt.pool_size > 0:
            self._create_prefix_prompt_pool()

    def _initialize_corpus(self) -> None:
        """Load and tokenize the corpus once, storing it for reuse."""
        corpus_path = Path(__file__).parent / DEFAULT_CORPUS_FILE

        with open(corpus_path) as f:
            lines = f.readlines()

        # Each chunk is tokenized independently and the results concatenated.
        # NOTE(review): token boundaries at chunk edges may differ slightly
        # from tokenizing the whole text at once — acceptable for synthetic
        # data, but confirm if exact token streams matter.
        def tokenize_chunk(chunk):
            cleaned_text = " ".join(line.strip() for line in chunk if line.strip())
            tokens = self.tokenizer.encode(cleaned_text)
            return tokens

        num_threads = os.cpu_count()
        if num_threads is None:
            num_threads = 4

        # Ensure chunk_size is at least 1 to avoid division by zero in range()
        chunk_size = max(1, len(lines) // num_threads)
        chunks = [lines[i : i + chunk_size] for i in range(0, len(lines), chunk_size)]

        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            tokenized_chunks = list(executor.map(tokenize_chunk, chunks))

        # Flatten the per-chunk token lists into one corpus-wide list.
        self._tokenized_corpus = [
            token for chunk in tokenized_chunks for token in chunk
        ]
        self._corpus_size = len(self._tokenized_corpus)
        self.debug(lambda: f"Initialized corpus with {self._corpus_size} tokens")

    def _create_prefix_prompt_pool(self) -> None:
        """Generate a pool of prefix prompts to sample from."""
        if self._tokenized_corpus is None:
            raise NotInitializedError("Tokenized corpus is not initialized.")

        self._prefix_prompts = [
            self._generate_prompt(self.config.prefix_prompt.length)
            for _ in range(self.config.prefix_prompt.pool_size)
        ]
        self.debug(
            lambda: f"Initialized prefix prompts pool with {len(self._prefix_prompts)} prompts"
        )

    def generate(
        self,
        mean: int | None = None,
        stddev: int | None = None,
        hash_ids: list[int] | None = None,
    ) -> str:
        """Generate a synthetic prompt with the configuration parameters.

        Args:
            mean: The mean of the normal distribution.
            stddev: The standard deviation of the normal distribution.
            hash_ids: A list of hash indices used for token reuse.

        Returns:
            A synthetic prompt as a string.
        """
        if hash_ids:
            # When hash_ids are given, `mean` is used directly as the exact
            # total token count (no sampling). NOTE(review): a None mean
            # would fail downstream here — confirm callers always pass it.
            return self._generate_cached_prompt(
                mean, hash_ids, self.config.input_tokens.block_size
            )

        num_tokens = utils.sample_positive_normal_integer(mean, stddev)
        return self._generate_prompt(num_tokens)

    def _generate_prompt(self, num_tokens: int) -> str:
        """Generate a prompt containing exactly `num_tokens` number of tokens.

        Args:
            num_tokens: Number of tokens required in the prompt.

        Returns:
            A synthetic prompt as a string.
        """
        return self.tokenizer.decode(self._sample_tokens(num_tokens))

    def _generate_cached_prompt(
        self,
        num_tokens: int,
        hash_ids: list[int],
        block_size: int,
    ) -> str:
        """
        Generate a prompt containing exactly `num_tokens` by reusing previously generated prompts
        stored in `_cache`. Each hash index in `hash_ids` corresponds to a block of
        `block_size` tokens. If a hash index is found in `_cache`, its stored prompt is reused.
        Otherwise, a new prompt is generated using `_generate_prompt()` and stored in `_cache`.

        Args:
            num_tokens: The number of tokens required in the prompt.
            hash_ids: A list of hash IDs to use for token reuse.
            block_size: The number of tokens allocated per hash block.

        Returns:
            str: A synthetic prompt as a string.

        Raises:
            ConfigurationError: If the input parameters are not compatible.
        """
        final_prompt: list[int] = []
        current_block_size = block_size

        # Sanity check the final block size
        final_block_size = num_tokens - ((len(hash_ids) - 1) * block_size)
        if final_block_size <= 0 or block_size < final_block_size:
            raise ConfigurationError(
                f"Input length: {num_tokens}, Hash IDs: {hash_ids}, Block size: {block_size} "
                f"are not compatible. The final hash block size: {final_block_size} must be "
                f"greater than 0 and less than or equal to {block_size}."
            )

        for index, hash_id in enumerate(hash_ids):
            # For the last hash ID, use the remaining tokens as the block size
            if index == len(hash_ids) - 1:
                current_block_size = final_block_size

            if hash_id not in self._cache:
                # To ensure that the prompt doesn't merge chunks, we pop the last token
                # and insert the bos token at the beginning. Length is maintained and
                # the prompt generates the expected number of tokens.
                # NOTE(review): the code pops the FIRST token (pop(0)), not the
                # last as the comment above says — confirm which is intended.
                prompt_tokens: list[int] = self._sample_tokens(current_block_size)
                prompt_tokens.pop(0)
                prompt_tokens.insert(0, self.tokenizer.bos_token_id)
                self._cache[hash_id] = prompt_tokens  # store to cache

            final_prompt.extend(self._cache[hash_id])

        return self.tokenizer.decode(final_prompt, skip_special_tokens=False)

    def _sample_tokens(self, num_tokens: int) -> list[int]:
        """Generate a list of token IDs containing exactly `num_tokens` number of tokens
        using the preloaded tokenized corpus.

        Args:
            num_tokens: Number of tokens required in the prompt.

        Returns:
            A list of token IDs.

        Raises:
            NotInitializedError: If the tokenized corpus is not initialized
        """
        if not self._tokenized_corpus:
            raise NotInitializedError("Tokenized corpus is not initialized.")
        if num_tokens > self._corpus_size:
            # NOTE(review): the wrap-around below can actually return up to
            # `num_tokens` tokens (with repetition) when num_tokens is less
            # than roughly twice the corpus size, so this warning's claim of
            # returning `corpus_size` tokens may overstate the truncation.
            self.warning(
                f"Requested prompt length {num_tokens} is longer than the corpus. "
                f"Returning a prompt of length {self._corpus_size}."
            )

        start_idx = random.randrange(self._corpus_size)

        end_idx = start_idx + num_tokens
        prompt_tokens = self._tokenized_corpus[start_idx:end_idx]
        # Wrap around to the start of the corpus if the slice ran off the end.
        if end_idx > self._corpus_size:
            prompt_tokens += self._tokenized_corpus[: end_idx - self._corpus_size]

        self.trace(lambda: f"Sampled {len(prompt_tokens)} tokens from corpus")
        return prompt_tokens

    def get_random_prefix_prompt(self) -> str:
        """
        Fetch a random prefix prompt from the pool.

        Returns:
            A random prefix prompt.

        Raises:
            InvalidStateError: If the prefix prompts pool is empty.
        """
        if not self._prefix_prompts:
            raise InvalidStateError(
                "Attempted to sample a prefix prompt but the prefix prompts pool is empty. "
                "Please ensure that the prefix prompts pool is initialized."
            )
        return random.choice(self._prefix_prompts)

generate(mean=None, stddev=None, hash_ids=None)

Generate a synthetic prompt with the configuration parameters.

Parameters:

Name Type Description Default
mean int | None

The mean of the normal distribution.

None
stddev int | None

The standard deviation of the normal distribution.

None
hash_ids list[int] | None

A list of hash indices used for token reuse.

None

Returns:

Type Description
str

A synthetic prompt as a string.

Source code in aiperf/dataset/generator/prompt.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def generate(
    self,
    mean: int | None = None,
    stddev: int | None = None,
    hash_ids: list[int] | None = None,
) -> str:
    """Create a synthetic prompt using the configured corpus.

    Args:
        mean: Mean of the normal distribution for the token count.
        stddev: Standard deviation of the normal distribution.
        hash_ids: Hash indices that enable cached token-block reuse.

    Returns:
        The generated prompt as a string.
    """
    if not hash_ids:
        length = utils.sample_positive_normal_integer(mean, stddev)
        return self._generate_prompt(length)

    # With hash IDs, `mean` is treated as the exact total token count.
    return self._generate_cached_prompt(
        mean, hash_ids, self.config.input_tokens.block_size
    )

get_random_prefix_prompt()

Fetch a random prefix prompt from the pool.

Returns:

Type Description
str

A random prefix prompt.

Raises:

Type Description
InvalidStateError

If the prefix prompts pool is empty.

Source code in aiperf/dataset/generator/prompt.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def get_random_prefix_prompt(self) -> str:
    """Return one prefix prompt sampled uniformly from the pool.

    Returns:
        A randomly chosen prefix prompt.

    Raises:
        InvalidStateError: If the prefix prompts pool has not been populated.
    """
    if self._prefix_prompts:
        return random.choice(self._prefix_prompts)
    raise InvalidStateError(
        "Attempted to sample a prefix prompt but the prefix prompts pool is empty. "
        "Please ensure that the prefix prompts pool is initialized."
    )

aiperf.dataset.loader.mixins

MediaConversionMixin

Mixin providing shared media conversion functionality for dataset loaders. It is used to construct text, image, and audio data from a CustomDatasetT object.

Source code in aiperf/dataset/loader/mixins.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class MediaConversionMixin:
    """Shared media-conversion helpers for dataset loaders.

    Builds text, image, and audio media objects out of a CustomDatasetT entry.
    """

    @property
    def _media_classes(self) -> list[type[MediaT]]:
        """All concrete Media subclasses, discovered dynamically."""
        return Media.__subclasses__()

    def convert_to_media_objects(
        self, data: CustomDatasetT, name: str = ""
    ) -> dict[str, list[MediaT]]:
        """Convert a custom dataset entry into media objects, keyed by media type.

        Args:
            data: The custom dataset to convert into media objects.
            name: The name of the media field.

        Returns:
            A dictionary of media objects.
        """
        return {
            cls.media_type: self._convert_to_media_objects(
                data,
                media_class=cls,
                field=cls.media_type,
                name=name,
            )
            for cls in self._media_classes
        }

    def _convert_to_media_objects(
        self,
        data: CustomDatasetT,
        media_class: type[MediaT],
        field: str,
        name: str = "",
    ) -> list[MediaT]:
        """Build media objects of one type from a CustomDatasetT object.

        Args:
            data: The custom dataset to construct media objects from.
            media_class: The target media class (Text, Image, or Audio).
            field: The name of the field (e.g., 'text', 'image', 'audio').
            name: The name of the media field.

        Returns:
            A list of media objects.
        """
        # A singular field wraps its single value in one media object.
        singular = getattr(data, field, None)
        if singular is not None:
            return [media_class(name=name, contents=[singular])]

        plural = getattr(data, f"{field}s", None)
        if plural is None or not isinstance(plural, Iterable):
            return []

        # Values that are already media objects of the right type pass through.
        if all(isinstance(item, media_class) for item in plural):
            return plural

        return [media_class(name=name, contents=plural)]

convert_to_media_objects(data, name='')

Convert all custom dataset into media objects.

Parameters:

Name Type Description Default
data CustomDatasetT

The custom dataset to convert into media objects.

required
name str

The name of the media field.

''

Returns:

Type Description
dict[str, list[MediaT]]

A dictionary of media objects.

Source code in aiperf/dataset/loader/mixins.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def convert_to_media_objects(
    self, data: CustomDatasetT, name: str = ""
) -> dict[str, list[MediaT]]:
    """Convert a custom dataset entry into media objects, keyed by media type.

    Args:
        data: The custom dataset to convert into media objects.
        name: The name of the media field.

    Returns:
        A dictionary of media objects.
    """
    return {
        cls.media_type: self._convert_to_media_objects(
            data,
            media_class=cls,
            field=cls.media_type,
            name=name,
        )
        for cls in self._media_classes
    }

aiperf.dataset.loader.models

CustomDatasetT = TypeVar('CustomDatasetT', bound=(SingleTurn | MultiTurn | RandomPool | MooncakeTrace)) module-attribute

A union type of all custom data types.

MooncakeTrace

Bases: AIPerfBaseModel

Defines the schema for Mooncake trace data.

See https://github.com/kvcache-ai/Mooncake for more details.

Example:

{"timestamp": 1000, "input_length": 10, "output_length": 4, "hash_ids": [123, 456]}
Source code in aiperf/dataset/loader/models.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
class MooncakeTrace(AIPerfBaseModel):
    """Defines the schema for Mooncake trace data.

    See https://github.com/kvcache-ai/Mooncake for more details.

    Example:
    ```json
    {"timestamp": 1000, "input_length": 10, "output_length": 4, "hash_ids": [123, 456]}
    ```
    """

    # Discriminator field identifying this entry as a Mooncake trace record.
    type: Literal[CustomDatasetType.MOONCAKE_TRACE] = CustomDatasetType.MOONCAKE_TRACE

    input_length: int = Field(..., description="The input sequence length of a request")
    output_length: int = Field(
        ..., description="The output sequence length of a request"
    )
    # Hash IDs identify reusable token blocks — presumably consumed by the
    # prompt generator's caching logic; verify against the dataset composer.
    hash_ids: list[int] = Field(..., description="The hash ids of a request")
    # NOTE(review): units of timestamp (ms vs s) are not stated here — confirm.
    timestamp: int = Field(..., description="The timestamp of a request")

MultiTurn

Bases: AIPerfBaseModel

Defines the schema for multi-turn conversations.

The multi-turn custom dataset - supports multi-modal data (e.g. text, image, audio) - supports multi-turn features (e.g. delay, sessions, etc.) - supports client-side batching for each data (e.g. batch size > 1)

Source code in aiperf/dataset/loader/models.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
class MultiTurn(AIPerfBaseModel):
    """Defines the schema for multi-turn conversations.

    The multi-turn custom dataset
      - supports multi-modal data (e.g. text, image, audio)
      - supports multi-turn features (e.g. delay, sessions, etc.)
      - supports client-side batching for each data (e.g. batch size > 1)
    """

    # Discriminator field identifying this entry as a multi-turn record.
    type: Literal[CustomDatasetType.MULTI_TURN] = CustomDatasetType.MULTI_TURN

    session_id: str | None = Field(
        None, description="Unique identifier for the conversation session"
    )
    # Required; the validator below additionally rejects an empty list.
    turns: list[SingleTurn] = Field(
        ..., description="List of turns in the conversation"
    )

    @model_validator(mode="after")
    def validate_turns_not_empty(self) -> "MultiTurn":
        """Ensure at least one turn is provided"""
        if not self.turns:
            raise ValueError("At least one turn must be provided")
        return self

validate_turns_not_empty()

Ensure at least one turn is provided

Source code in aiperf/dataset/loader/models.py
92
93
94
95
96
97
@model_validator(mode="after")
def validate_turns_not_empty(self) -> "MultiTurn":
    """Reject conversations that contain no turns."""
    if self.turns:
        return self
    raise ValueError("At least one turn must be provided")

RandomPool

Bases: AIPerfBaseModel

Defines the schema for random pool data entry.

The random pool custom dataset - supports multi-modal data (e.g. text, image, audio) - supports client-side batching for each data (e.g. batch size > 1) - supports named fields for each modality (e.g. text_field_a, text_field_b, etc.) - DOES NOT support multi-turn or its features (e.g. delay, sessions, etc.)

Source code in aiperf/dataset/loader/models.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
class RandomPool(AIPerfBaseModel):
    """Defines the schema for random pool data entry.

    The random pool custom dataset
      - supports multi-modal data (e.g. text, image, audio)
      - supports client-side batching for each data (e.g. batch size > 1)
      - supports named fields for each modality (e.g. text_field_a, text_field_b, etc.)
      - DOES NOT support multi-turn or its features (e.g. delay, sessions, etc.)
    """

    # Discriminator field identifying this entry as a random-pool record.
    type: Literal[CustomDatasetType.RANDOM_POOL] = CustomDatasetType.RANDOM_POOL

    # For each modality, the singular and plural forms are mutually exclusive
    # (enforced by validate_mutually_exclusive_fields below).
    text: str | None = Field(None, description="Simple text string content")
    texts: list[str] | list[Text] | None = Field(
        None,
        description="List of text strings or Text objects format",
    )
    image: str | None = Field(None, description="Simple image string content")
    images: list[str] | list[Image] | None = Field(
        None,
        description="List of image strings or Image objects format",
    )
    audio: str | None = Field(None, description="Simple audio string content")
    audios: list[str] | list[Audio] | None = Field(
        None,
        description="List of audio strings or Audio objects format",
    )

    @model_validator(mode="after")
    def validate_mutually_exclusive_fields(self) -> "RandomPool":
        """Ensure mutually exclusive fields are not set together"""
        if self.text and self.texts:
            raise ValueError("text and texts cannot be set together")
        if self.image and self.images:
            raise ValueError("image and images cannot be set together")
        if self.audio and self.audios:
            raise ValueError("audio and audios cannot be set together")
        return self

    @model_validator(mode="after")
    def validate_at_least_one_modality(self) -> "RandomPool":
        """Ensure at least one modality is provided"""
        # Empty strings/lists are treated the same as None here (truthiness).
        if not any(
            [self.text, self.texts, self.image, self.images, self.audio, self.audios]
        ):
            raise ValueError("At least one modality must be provided")
        return self

validate_at_least_one_modality()

Ensure at least one modality is provided

Source code in aiperf/dataset/loader/models.py
139
140
141
142
143
144
145
146
@model_validator(mode="after")
def validate_at_least_one_modality(self) -> "RandomPool":
    """Reject entries that carry no modality content at all."""
    modalities = (
        self.text, self.texts, self.image, self.images, self.audio, self.audios
    )
    if not any(modalities):
        raise ValueError("At least one modality must be provided")
    return self

validate_mutually_exclusive_fields()

Ensure mutually exclusive fields are not set together

Source code in aiperf/dataset/loader/models.py
128
129
130
131
132
133
134
135
136
137
@model_validator(mode="after")
def validate_mutually_exclusive_fields(self) -> "RandomPool":
    """Reject entries that set both the singular and plural form of a modality."""
    pairs = (("text", "texts"), ("image", "images"), ("audio", "audios"))
    for singular, plural in pairs:
        if getattr(self, singular) and getattr(self, plural):
            raise ValueError(f"{singular} and {plural} cannot be set together")
    return self

SingleTurn

Bases: AIPerfBaseModel

Defines the schema for single-turn data.

User can use this format to quickly provide a custom single turn dataset. Each line in the file will be treated as a single turn conversation.

The single turn type - supports multi-modal (e.g. text, image, audio) - supports client-side batching for each data (e.g. batch_size > 1) - DOES NOT support multi-turn features (e.g. session_id)

Source code in aiperf/dataset/loader/models.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class SingleTurn(AIPerfBaseModel):
    """Defines the schema for single-turn data.

    User can use this format to quickly provide a custom single turn dataset.
    Each line in the file will be treated as a single turn conversation.

    The single turn type
      - supports multi-modal (e.g. text, image, audio)
      - supports client-side batching for each data (e.g. batch_size > 1)
      - DOES NOT support multi-turn features (e.g. session_id)
    """

    type: Literal[CustomDatasetType.SINGLE_TURN] = CustomDatasetType.SINGLE_TURN

    # TODO (TL-89): investigate if we only want to support single field for each modality
    text: str | None = Field(None, description="Simple text string content")
    texts: list[str] | list[Text] | None = Field(
        None,
        description="List of text strings or Text objects format",
    )
    image: str | None = Field(None, description="Simple image string content")
    images: list[str] | list[Image] | None = Field(
        None,
        description="List of image strings or Image objects format",
    )
    audio: str | None = Field(None, description="Simple audio string content")
    audios: list[str] | list[Audio] | None = Field(
        None,
        description="List of audio strings or Audio objects format",
    )
    timestamp: int | None = Field(
        default=None, description="Timestamp of the turn in milliseconds."
    )
    delay: int | None = Field(
        default=None,
        description="Amount of milliseconds to wait before sending the turn.",
    )
    role: str | None = Field(default=None, description="Role of the turn.")

    @model_validator(mode="after")
    def validate_mutually_exclusive_fields(self) -> "SingleTurn":
        """Ensure mutually exclusive fields are not set together.

        Raises:
            ValueError: If both the scalar and batched form of a modality are
                set, or if both timestamp and delay are set.
        """
        if self.text and self.texts:
            raise ValueError("text and texts cannot be set together")
        if self.image and self.images:
            raise ValueError("image and images cannot be set together")
        if self.audio and self.audios:
            raise ValueError("audio and audios cannot be set together")
        # Compare against None explicitly: 0 is a valid timestamp/delay value
        # (e.g. the first entry of a fixed schedule), and a truthiness check
        # would silently let timestamp=0 coexist with a delay (or vice versa).
        if self.timestamp is not None and self.delay is not None:
            raise ValueError("timestamp and delay cannot be set together")
        return self

    @model_validator(mode="after")
    def validate_at_least_one_modality(self) -> "SingleTurn":
        """Ensure at least one modality is provided.

        Raises:
            ValueError: If no text, image, or audio field holds content.
        """
        if not any(
            [self.text, self.texts, self.image, self.images, self.audio, self.audios]
        ):
            raise ValueError("At least one modality must be provided")
        return self

validate_at_least_one_modality()

Ensure at least one modality is provided

Source code in aiperf/dataset/loader/models.py
64
65
66
67
68
69
70
71
@model_validator(mode="after")
def validate_at_least_one_modality(self) -> "SingleTurn":
    """Reject turns that carry no text, image, or audio content at all."""
    # Accept the turn as soon as any modality field holds a truthy value.
    for value in (self.text, self.texts, self.image, self.images, self.audio, self.audios):
        if value:
            return self
    raise ValueError("At least one modality must be provided")

validate_mutually_exclusive_fields()

Ensure mutually exclusive fields are not set together

Source code in aiperf/dataset/loader/models.py
51
52
53
54
55
56
57
58
59
60
61
62
@model_validator(mode="after")
def validate_mutually_exclusive_fields(self) -> "SingleTurn":
    """Ensure mutually exclusive fields are not set together.

    Raises:
        ValueError: If both the scalar and batched form of a modality are
            set, or if both timestamp and delay are set.
    """
    if self.text and self.texts:
        raise ValueError("text and texts cannot be set together")
    if self.image and self.images:
        raise ValueError("image and images cannot be set together")
    if self.audio and self.audios:
        raise ValueError("audio and audios cannot be set together")
    # Compare against None explicitly: 0 is a valid timestamp/delay value
    # (e.g. the first entry of a fixed schedule), and a truthiness check
    # would silently let timestamp=0 coexist with a delay (or vice versa).
    if self.timestamp is not None and self.delay is not None:
        raise ValueError("timestamp and delay cannot be set together")
    return self

aiperf.dataset.loader.mooncake_trace

MooncakeTraceDatasetLoader

A dataset loader that loads Mooncake trace data from a file.

Loads Mooncake trace data from a file and converts the data into a list of conversations for dataset manager.

Each line in the file represents a single trace entry and will be converted to a separate conversation with a unique session ID.

Example: Fixed schedule version (Each line is a distinct session. Multi-turn is NOT supported)

{"timestamp": 1000, "input_length": 300, "output_length": 40, "hash_ids": [123, 456]}
Source code in aiperf/dataset/loader/mooncake_trace.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
@CustomDatasetFactory.register(CustomDatasetType.MOONCAKE_TRACE)
class MooncakeTraceDatasetLoader:
    """Dataset loader for Mooncake trace files.

    Reads trace entries from a JSONL file and turns them into conversation
    objects for the dataset manager. Every line is one trace entry and maps
    to its own single-turn conversation under a freshly generated session ID.

    Example line (fixed schedule; multi-turn is NOT supported):
    ```json
    {"timestamp": 1000, "input_length": 300, "output_length": 40, "hash_ids": [123, 456]}
    ```
    """

    def __init__(self, filename: str, prompt_generator: PromptGenerator):
        self.filename = filename
        self.prompt_generator = prompt_generator

    def load_dataset(self) -> dict[str, list[MooncakeTrace]]:
        """Read and parse the trace file.

        Returns:
            A dictionary of session_id and list of Mooncake trace data.
        """
        traces_by_session: dict[str, list[MooncakeTrace]] = defaultdict(list)

        with open(self.filename) as trace_file:
            for raw_line in trace_file:
                stripped = raw_line.strip()
                if not stripped:
                    # Blank lines carry no trace entry.
                    continue
                # Every non-empty line becomes its own session.
                session_id = str(uuid.uuid4())
                traces_by_session[session_id].append(
                    MooncakeTrace.model_validate_json(stripped)
                )

        return traces_by_session

    def convert_to_conversations(
        self, data: dict[str, list[MooncakeTrace]]
    ) -> list[Conversation]:
        """Convert parsed trace entries into conversation objects.

        Args:
            data: A dictionary of session_id and list of Mooncake trace data.

        Returns:
            A list of conversations.
        """
        result: list[Conversation] = []
        for session_id, traces in data.items():
            conversation = Conversation(session_id=session_id)
            # Synthesize a prompt per trace; stddev=0 pins the prompt length
            # to the trace's input_length.
            turns = [
                Turn(
                    timestamp=trace.timestamp,
                    texts=[
                        Text(
                            name="text",
                            contents=[
                                self.prompt_generator.generate(
                                    mean=trace.input_length,
                                    stddev=0,
                                    hash_ids=trace.hash_ids,
                                )
                            ],
                        )
                    ],
                )
                for trace in traces
            ]
            conversation.turns.extend(turns)
            result.append(conversation)
        return result

convert_to_conversations(data)

Convert all the Mooncake trace data to conversation objects.

Parameters:

Name Type Description Default
data dict[str, list[MooncakeTrace]]

A dictionary of session_id and list of Mooncake trace data.

required

Returns:

Type Description
list[Conversation]

A list of conversations.

Source code in aiperf/dataset/loader/mooncake_trace.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def convert_to_conversations(
    self, data: dict[str, list[MooncakeTrace]]
) -> list[Conversation]:
    """Convert parsed trace entries into conversation objects.

    Args:
        data: A dictionary of session_id and list of Mooncake trace data.

    Returns:
        A list of conversations.
    """
    result: list[Conversation] = []
    for session_id, traces in data.items():
        conversation = Conversation(session_id=session_id)
        # Synthesize a prompt per trace; stddev=0 pins the prompt length
        # to the trace's input_length.
        turns = [
            Turn(
                timestamp=trace.timestamp,
                texts=[
                    Text(
                        name="text",
                        contents=[
                            self.prompt_generator.generate(
                                mean=trace.input_length,
                                stddev=0,
                                hash_ids=trace.hash_ids,
                            )
                        ],
                    )
                ],
            )
            for trace in traces
        ]
        conversation.turns.extend(turns)
        result.append(conversation)
    return result

load_dataset()

Load Mooncake trace data from a file.

Returns:

Type Description
dict[str, list[MooncakeTrace]]

A dictionary of session_id and list of Mooncake trace data.

Source code in aiperf/dataset/loader/mooncake_trace.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def load_dataset(self) -> dict[str, list[MooncakeTrace]]:
    """Read and parse the Mooncake trace file.

    Returns:
        A dictionary of session_id and list of Mooncake trace data.
    """
    traces_by_session: dict[str, list[MooncakeTrace]] = defaultdict(list)

    with open(self.filename) as trace_file:
        for raw_line in trace_file:
            stripped = raw_line.strip()
            if not stripped:
                # Blank lines carry no trace entry.
                continue
            # Every non-empty line becomes its own session.
            session_id = str(uuid.uuid4())
            traces_by_session[session_id].append(
                MooncakeTrace.model_validate_json(stripped)
            )

    return traces_by_session

aiperf.dataset.loader.multi_turn

MultiTurnDatasetLoader

Bases: MediaConversionMixin

A dataset loader that loads multi-turn data from a file.

The multi-turn type - supports multi-modal data (e.g. text, image, audio) - supports multi-turn features (e.g. delay, sessions, etc.) - supports client-side batching for each data (e.g. batch_size > 1)

NOTE: If the user specifies multiple multi-turn entries with same session ID, the loader will group them together. If the timestamps are specified, they will be sorted in ascending order later in the timing manager.

Examples: 1. Simple version

{
    "session_id": "session_123",
    "turns": [
        {"text": "Hello", "image": "url", "delay": 0},
        {"text": "Hi there", "delay": 1000}
    ]
}
  2. Batched version
{
    "session_id": "session_123",
    "turns": [
        {"texts": ["Who are you?", "Hello world"], "images": ["/path/1.png", "/path/2.png"]},
        {"texts": ["What is in the image?", "What is AI?"], "images": ["/path/3.png", "/path/4.png"]}
    ]
}
  3. Fixed schedule version
{
    "session_id": "session_123",
    "turns": [
        {"timestamp": 0, "text": "What is deep learning?"},
        {"timestamp": 1000, "text": "Who are you?"}
    ]
}
  4. Time delayed version
{
    "session_id": "session_123",
    "turns": [
        {"delay": 0, "text": "What is deep learning?"},
        {"delay": 1000, "text": "Who are you?"}
    ]
}
  5. Full-featured version (multi-batch, multi-modal, multi-fielded, session-based, etc.)
{
    "session_id": "session_123",
    "turns": [
        {
            "timestamp": 1234,
            "texts": [
                {"name": "text_field_a", "contents": ["hello", "world"]},
                {"name": "text_field_b", "contents": ["hi there"]}
            ],
            "images": [
                {"name": "image_field_a", "contents": ["/path/1.png", "/path/2.png"]},
                {"name": "image_field_b", "contents": ["/path/3.png"]}
            ]
        }
    ]
}
Source code in aiperf/dataset/loader/multi_turn.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
@CustomDatasetFactory.register(CustomDatasetType.MULTI_TURN)
class MultiTurnDatasetLoader(MediaConversionMixin):
    """Dataset loader for JSONL files holding multi-turn conversation data.

    Each input line is a complete multi-turn conversation: an optional
    ``session_id`` plus a list of turns. The multi-turn type
      - supports multi-modal data (e.g. text, image, audio)
      - supports multi-turn features (e.g. delay, sessions, etc.)
      - supports client-side batching for each data (e.g. batch_size > 1)

    NOTE: Lines that share a session ID are grouped into one conversation.
    When timestamps are specified, they are sorted in ascending order later
    in the timing manager.

    Example lines (simple and full-featured forms):
    ```json
    {
        "session_id": "session_123",
        "turns": [
            {"text": "Hello", "image": "url", "delay": 0},
            {"text": "Hi there", "delay": 1000}
        ]
    }
    ```

    ```json
    {
        "session_id": "session_123",
        "turns": [
            {
                "timestamp": 1234,
                "texts": [
                    {"name": "text_field_a", "contents": ["hello", "world"]},
                    {"name": "text_field_b", "contents": ["hi there"]}
                ],
                "images": [
                    {"name": "image_field_a", "contents": ["/path/1.png", "/path/2.png"]},
                    {"name": "image_field_b", "contents": ["/path/3.png"]}
                ]
            }
        ]
    }
    ```
    Turns may also carry a bare list form ("texts": ["a", "b"]), a fixed
    schedule ("timestamp"), or a relative delay ("delay").
    """

    def __init__(self, filename: str):
        self.filename = filename

    def load_dataset(self) -> dict[str, list[MultiTurn]]:
        """Parse the JSONL file into MultiTurn records grouped by session ID.

        Each line represents a complete multi-turn conversation with its own
        session_id and multiple turns. Lines without a session_id are keyed
        under a freshly generated UUID.

        Returns:
            A dictionary mapping session_id to list of MultiTurn objects.
        """
        grouped: dict[str, list[MultiTurn]] = defaultdict(list)

        with open(self.filename) as dataset_file:
            for raw_line in dataset_file:
                stripped = raw_line.strip()
                if not stripped:
                    # Skip blank lines.
                    continue
                record = MultiTurn.model_validate_json(stripped)
                key = record.session_id or str(uuid.uuid4())
                grouped[key].append(record)

        return grouped

    def convert_to_conversations(
        self, data: dict[str, list[MultiTurn]]
    ) -> list[Conversation]:
        """Build Conversation objects from grouped MultiTurn records.

        Args:
            data: A dictionary mapping session_id to list of MultiTurn objects.

        Returns:
            A list of conversations.
        """
        result: list[Conversation] = []
        for session_id, records in data.items():
            conversation = Conversation(session_id=session_id)

            # Flatten every record's turns into the session's conversation.
            for record in records:
                for turn_data in record.turns:
                    media = self.convert_to_media_objects(turn_data)
                    conversation.turns.append(
                        Turn(
                            texts=media[MediaType.TEXT],
                            images=media[MediaType.IMAGE],
                            audios=media[MediaType.AUDIO],
                            timestamp=turn_data.timestamp,
                            delay=turn_data.delay,
                            role=turn_data.role,
                        )
                    )
            result.append(conversation)
        return result

convert_to_conversations(data)

Convert multi-turn data to conversation objects.

Parameters:

Name Type Description Default
data dict[str, list[MultiTurn]]

A dictionary mapping session_id to list of MultiTurn objects.

required

Returns:

Type Description
list[Conversation]

A list of conversations.

Source code in aiperf/dataset/loader/multi_turn.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def convert_to_conversations(
    self, data: dict[str, list[MultiTurn]]
) -> list[Conversation]:
    """Build Conversation objects from grouped MultiTurn records.

    Args:
        data: A dictionary mapping session_id to list of MultiTurn objects.

    Returns:
        A list of conversations.
    """
    result: list[Conversation] = []
    for session_id, records in data.items():
        conversation = Conversation(session_id=session_id)

        # Flatten every record's turns into the session's conversation.
        for record in records:
            for turn_data in record.turns:
                media = self.convert_to_media_objects(turn_data)
                conversation.turns.append(
                    Turn(
                        texts=media[MediaType.TEXT],
                        images=media[MediaType.IMAGE],
                        audios=media[MediaType.AUDIO],
                        timestamp=turn_data.timestamp,
                        delay=turn_data.delay,
                        role=turn_data.role,
                    )
                )
        result.append(conversation)
    return result

load_dataset()

Load multi-turn data from a JSONL file.

Each line represents a complete multi-turn conversation with its own session_id and multiple turns.

Returns:

Type Description
dict[str, list[MultiTurn]]

A dictionary mapping session_id to list of MultiTurn objects.

Source code in aiperf/dataset/loader/multi_turn.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def load_dataset(self) -> dict[str, list[MultiTurn]]:
    """Parse the JSONL file into MultiTurn records grouped by session ID.

    Each line represents a complete multi-turn conversation with its own
    session_id and multiple turns. Lines without a session_id are keyed
    under a freshly generated UUID.

    Returns:
        A dictionary mapping session_id to list of MultiTurn objects.
    """
    grouped: dict[str, list[MultiTurn]] = defaultdict(list)

    with open(self.filename) as dataset_file:
        for raw_line in dataset_file:
            stripped = raw_line.strip()
            if not stripped:
                # Skip blank lines.
                continue
            record = MultiTurn.model_validate_json(stripped)
            key = record.session_id or str(uuid.uuid4())
            grouped[key].append(record)

    return grouped

aiperf.dataset.loader.protocol

aiperf.dataset.loader.random_pool

RandomPoolDatasetLoader

Bases: MediaConversionMixin

A dataset loader that loads data from a single file or a directory.

Each line in the file represents single-turn conversation data, and files create individual pools for random sampling: - Single file: All lines form one single pool (to be randomly sampled from) - Directory: Each file becomes a separate pool, then pools are randomly sampled and merged into conversations later.

The random pool custom dataset - supports multi-modal data (e.g. text, image, audio) - supports client-side batching for each data (e.g. batch size > 1) - supports named fields for each modality (e.g. text_field_a, text_field_b, etc.) - DOES NOT support multi-turn or its features (e.g. delay, sessions, etc.)

Example:

  1. Single file
{"text": "Who are you?", "image": "/path/to/image1.png"}
{"text": "Explain what is the meaning of life.", "image": "/path/to/image2.png"}
...

The file will form a single pool of text and image data that will be used to generate conversations.

  2. Directory

Directory will be useful if user wants to - create multiple pools of different modalities separately (e.g. text, image) - specify different field names for the same modality.

data/queries.jsonl

{"texts": [{"name": "query", "contents": ["Who are you?"]}]}
{"texts": [{"name": "query", "contents": ["What is the meaning of life?"]}]}
...

data/passages.jsonl

{"texts": [{"name": "passage", "contents": ["I am a cat."]}]}
{"texts": [{"name": "passage", "contents": ["I am a dog."]}]}
...

The loader will create two separate pools for each file: queries and passages. Each pool is a text dataset with a different field name (e.g. query, passage), and loader will later sample from these two pools to create conversations.

Source code in aiperf/dataset/loader/random_pool.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
@CustomDatasetFactory.register(CustomDatasetType.RANDOM_POOL)
class RandomPoolDatasetLoader(MediaConversionMixin):
    """Dataset loader that builds sampling pools from a file or a directory.

    Every line is one single-turn entry, and each file forms one pool for
    random sampling:
      - Single file: all lines form one pool (randomly sampled from).
      - Directory: each file becomes its own pool; pools are sampled
        independently and the samples are merged into conversations later.

    The random pool custom dataset
      - supports multi-modal data (e.g. text, image, audio)
      - supports client-side batching for each data (e.g. batch size > 1)
      - supports named fields for each modality (e.g. text_field_a, text_field_b, etc.)
      - DOES NOT support multi-turn or its features (e.g. delay, sessions, etc.)

    Example:

    1. Single file
    ```jsonl
    {"text": "Who are you?", "image": "/path/to/image1.png"}
    {"text": "Explain what is the meaning of life.", "image": "/path/to/image2.png"}
    ...
    ```
    The file forms one pool of text and image data used to generate
    conversations.

    2. Directory

    Useful when the user wants to
      - create multiple pools of different modalities separately (e.g. text, image)
      - specify different field names for the same modality.

    data/queries.jsonl
    ```jsonl
    {"texts": [{"name": "query", "contents": ["Who are you?"]}]}
    {"texts": [{"name": "query", "contents": ["What is the meaning of life?"]}]}
    ...
    ```

    data/passages.jsonl
    ```jsonl
    {"texts": [{"name": "passage", "contents": ["I am a cat."]}]}
    {"texts": [{"name": "passage", "contents": ["I am a dog."]}]}
    ...
    ```

    Here the loader creates one pool per file (queries, passages), each a
    text dataset with its own field name, and later samples from both pools
    to assemble conversations.
    """

    def __init__(self, filename: str, num_conversations: int = 1):
        self.filename = filename
        self.num_conversations = num_conversations

    def load_dataset(self) -> dict[Filename, list[RandomPool]]:
        """Load random pool data from a file or directory.

        A regular file is parsed into a single pool keyed by its name; a
        directory yields one pool per contained file.

        Returns:
            A dictionary mapping filename to list of RandomPool objects.
        """
        source = Path(self.filename)

        if source.is_file():
            return {source.name: self._load_dataset_from_file(source)}

        return self._load_dataset_from_dir(source)

    def _load_dataset_from_file(self, file_path: Path) -> list[RandomPool]:
        """Parse one JSONL file into a list of RandomPool entries.

        Args:
            file_path: The path to the file containing the data.

        Returns:
            A list of RandomPool objects.
        """
        pool: list[RandomPool] = []

        with open(file_path) as pool_file:
            for raw_line in pool_file:
                stripped = raw_line.strip()
                if not stripped:
                    # Skip blank lines.
                    continue
                pool.append(RandomPool.model_validate_json(stripped))

        return pool

    def _load_dataset_from_dir(
        self, dir_path: Path
    ) -> dict[Filename, list[RandomPool]]:
        """Parse every regular file in the directory into its own pool.

        Args:
            dir_path: The path to the directory containing the files.

        Returns:
            A dictionary mapping filename to list of RandomPool objects.
        """
        pools: dict[Filename, list[RandomPool]] = defaultdict(list)

        for entry in dir_path.iterdir():
            if entry.is_file():
                pools[entry.name].extend(self._load_dataset_from_file(entry))

        return pools

    def convert_to_conversations(
        self, data: dict[Filename, list[RandomPool]]
    ) -> list[Conversation]:
        """Sample each pool and merge the samples into conversations.

        Each conversation holds one turn, merged from one sample per pool,
        under a unique session ID.

        Args:
            data: A dictionary mapping filename to list of RandomPool objects.

        Returns:
            A list of conversations.
        """
        conversations = [
            Conversation(session_id=str(uuid.uuid4()))
            for _ in range(self.num_conversations)
        ]

        # F x N (F: num of files, N: num of conversations)
        sampled: dict[Filename, list[Turn]] = {}

        # Sample with replacement from every pool; the file stem names the
        # media fields for that pool.
        for filename, pool in data.items():
            field_name = Path(filename).stem
            picks = random.choices(pool, k=self.num_conversations)
            sampled[filename] = [
                self._sample_to_turn(pick, field_name) for pick in picks
            ]

        # Zip the per-file turn lists and merge one turn per conversation.
        for conversation, per_file_turns in zip(
            conversations, zip(*sampled.values(), strict=False), strict=False
        ):
            conversation.turns.append(self._merge_turns(per_file_turns))

        return conversations

    def _sample_to_turn(self, sample: RandomPool, field_name: str) -> Turn:
        """Wrap one sampled entry's media into a single Turn."""
        media = self.convert_to_media_objects(sample, name=field_name)
        return Turn(
            texts=media[MediaType.TEXT],
            images=media[MediaType.IMAGE],
            audios=media[MediaType.AUDIO],
        )

    def _merge_turns(self, turns: list[Turn]) -> Turn:
        """Concatenate the media lists of several turns into one turn.

        Args:
            turns: A list of turns.

        Returns:
            A single turn.
        """
        return Turn(
            texts=[text for turn in turns for text in turn.texts],
            images=[image for turn in turns for image in turn.images],
            audios=[audio for turn in turns for audio in turn.audios],
        )

convert_to_conversations(data)

Convert random pool data to conversation objects.

Each RandomPool entry becomes a single-turn conversation with a unique session ID.

Parameters:

Name Type Description Default
data dict[Filename, list[RandomPool]]

A dictionary mapping filename to list of RandomPool objects.

required

Returns:

Type Description
list[Conversation]

A list of conversations.

Source code in aiperf/dataset/loader/random_pool.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def convert_to_conversations(
    self, data: dict[Filename, list[RandomPool]]
) -> list[Conversation]:
    """Sample each pool and merge the samples into conversations.

    Each conversation holds one turn, merged from one sample per pool,
    under a unique session ID.

    Args:
        data: A dictionary mapping filename to list of RandomPool objects.

    Returns:
        A list of conversations.
    """
    conversations = [
        Conversation(session_id=str(uuid.uuid4()))
        for _ in range(self.num_conversations)
    ]

    # F x N (F: num of files, N: num of conversations)
    sampled: dict[Filename, list[Turn]] = {}

    # Sample with replacement from every pool; the file stem names the
    # media fields for that pool.
    for filename, pool in data.items():
        field_name = Path(filename).stem
        picks = random.choices(pool, k=self.num_conversations)
        turns: list[Turn] = []
        for pick in picks:
            media = self.convert_to_media_objects(pick, name=field_name)
            turns.append(
                Turn(
                    texts=media[MediaType.TEXT],
                    images=media[MediaType.IMAGE],
                    audios=media[MediaType.AUDIO],
                )
            )
        sampled[filename] = turns

    # Zip the per-file turn lists and merge one turn per conversation.
    for conversation, per_file_turns in zip(
        conversations, zip(*sampled.values(), strict=False), strict=False
    ):
        conversation.turns.append(self._merge_turns(per_file_turns))

    return conversations

load_dataset()

Load random pool data from a file or directory.

If filename is a file, reads and parses using the RandomPool model. If filename is a directory, reads each file in the directory and merges items with different modality names into combined RandomPool objects.

Returns:

Type Description
dict[Filename, list[RandomPool]]

A dictionary mapping filename to list of RandomPool objects.

Source code in aiperf/dataset/loader/random_pool.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def load_dataset(self) -> dict[Filename, list[RandomPool]]:
    """Load random pool data from a file or directory.

    If filename is a file, reads and parses it using the RandomPool model.
    If filename is a directory, reads each contained file and merges items
    with different modality names into combined RandomPool objects.

    Returns:
        A dictionary mapping filename to list of RandomPool objects.
    """
    source = Path(self.filename)

    # Directories are handled by the multi-file loader; a single file maps
    # its own name to the parsed pool.
    if not source.is_file():
        return self._load_dataset_from_dir(source)
    return {source.name: self._load_dataset_from_file(source)}

aiperf.dataset.loader.single_turn

SingleTurnDatasetLoader

Bases: MediaConversionMixin

A dataset loader that loads single turn data from a file.

The single turn type - supports multi-modal data (e.g. text, image, audio) - supports client-side batching for each data (e.g. batch_size > 1) - DOES NOT support multi-turn features (e.g. delay, sessions, etc.)

Examples: 1. Single-batch, text only

{"text": "What is deep learning?"}
  1. Single-batch, multi-modal
{"text": "What is in the image?", "image": "/path/to/image.png"}
  1. Multi-batch, multi-modal
{"texts": ["Who are you?", "Hello world"], "images": ["/path/to/image.png", "/path/to/image2.png"]}
  1. Fixed schedule version
{"timestamp": 0, "text": "What is deep learning?"},
{"timestamp": 1000, "text": "Who are you?"},
{"timestamp": 2000, "text": "What is AI?"}
  1. Time delayed version
{"delay": 0, "text": "What is deep learning?"},
{"delay": 1234, "text": "Who are you?"}
  1. Full-featured version (Multi-batch, multi-modal, multi-fielded)
{
    "texts": [
        {"name": "text_field_A", "contents": ["Hello", "World"]},
        {"name": "text_field_B", "contents": ["Hi there"]}
    ],
    "images": [
        {"name": "image_field_A", "contents": ["/path/1.png", "/path/2.png"]},
        {"name": "image_field_B", "contents": ["/path/3.png"]}
    ]
}
Source code in aiperf/dataset/loader/single_turn.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
@CustomDatasetFactory.register(CustomDatasetType.SINGLE_TURN)
class SingleTurnDatasetLoader(MediaConversionMixin):
    """A dataset loader that loads single turn data from a file.

    The single turn type
      - supports multi-modal data (e.g. text, image, audio)
      - supports client-side batching for each data (e.g. batch_size > 1)
      - DOES NOT support multi-turn features (e.g. delay, sessions, etc.)

    Examples:
    1. Single-batch, text only
    ```json
    {"text": "What is deep learning?"}
    ```

    2. Single-batch, multi-modal
    ```json
    {"text": "What is in the image?", "image": "/path/to/image.png"}
    ```

    3. Multi-batch, multi-modal
    ```json
    {"texts": ["Who are you?", "Hello world"], "images": ["/path/to/image.png", "/path/to/image2.png"]}
    ```

    4. Fixed schedule version
    ```json
    {"timestamp": 0, "text": "What is deep learning?"},
    {"timestamp": 1000, "text": "Who are you?"},
    {"timestamp": 2000, "text": "What is AI?"}
    ```

    5. Time delayed version
    ```json
    {"delay": 0, "text": "What is deep learning?"},
    {"delay": 1234, "text": "Who are you?"}
    ```

    6. Full-featured version (Multi-batch, multi-modal, multi-fielded)
    ```json
    {
        "texts": [
            {"name": "text_field_A", "contents": ["Hello", "World"]},
            {"name": "text_field_B", "contents": ["Hi there"]}
        ],
        "images": [
            {"name": "image_field_A", "contents": ["/path/1.png", "/path/2.png"]},
            {"name": "image_field_B", "contents": ["/path/3.png"]}
        ]
    }
    ```
    """

    def __init__(self, filename: str):
        self.filename = filename

    def load_dataset(self) -> dict[str, list[SingleTurn]]:
        """Load single-turn records from a JSONL file.

        Each non-empty line is parsed as one SingleTurn and stored under a
        freshly generated UUID session key, so every line becomes its own
        session.

        Returns:
            A dictionary mapping session_id to a list of SingleTurn objects.
        """
        data: dict[str, list[SingleTurn]] = defaultdict(list)

        with open(self.filename) as f:
            for raw in f:
                stripped = raw.strip()
                if not stripped:
                    # Ignore blank lines between records.
                    continue
                record = SingleTurn.model_validate_json(stripped)
                data[str(uuid.uuid4())].append(record)

        return data

    def convert_to_conversations(
        self, data: dict[str, list[SingleTurn]]
    ) -> list[Conversation]:
        """Convert single turn data to conversation objects.

        Args:
            data: A dictionary mapping session_id to list of SingleTurn objects.

        Returns:
            A list of conversations.
        """
        conversations: list[Conversation] = []
        for session_id, single_turns in data.items():
            conversation = Conversation(session_id=session_id)
            conversation.turns.extend(
                self._build_turn(single_turn) for single_turn in single_turns
            )
            conversations.append(conversation)
        return conversations

    def _build_turn(self, single_turn: SingleTurn) -> Turn:
        """Build a Turn from one record's media and scheduling fields."""
        media = self.convert_to_media_objects(single_turn)
        return Turn(
            texts=media[MediaType.TEXT],
            images=media[MediaType.IMAGE],
            audios=media[MediaType.AUDIO],
            timestamp=single_turn.timestamp,
            delay=single_turn.delay,
            role=single_turn.role,
        )

convert_to_conversations(data)

Convert single turn data to conversation objects.

Parameters:

Name Type Description Default
data dict[str, list[SingleTurn]]

A dictionary mapping session_id to list of SingleTurn objects.

required

Returns:

Type Description
list[Conversation]

A list of conversations.

Source code in aiperf/dataset/loader/single_turn.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
def convert_to_conversations(
    self, data: dict[str, list[SingleTurn]]
) -> list[Conversation]:
    """Convert single turn data to conversation objects.

    Args:
        data: A dictionary mapping session_id to list of SingleTurn objects.

    Returns:
        A list of conversations.
    """
    result: list[Conversation] = []
    for session_id, turns in data.items():
        conversation = Conversation(session_id=session_id)
        for item in turns:
            # Resolve the record's media into typed objects, then carry the
            # scheduling fields (timestamp/delay) and role onto the Turn.
            media = self.convert_to_media_objects(item)
            turn = Turn(
                texts=media[MediaType.TEXT],
                images=media[MediaType.IMAGE],
                audios=media[MediaType.AUDIO],
                timestamp=item.timestamp,
                delay=item.delay,
                role=item.role,
            )
            conversation.turns.append(turn)
        result.append(conversation)
    return result

load_dataset()

Load single-turn data from a JSONL file.

Each line represents a single turn conversation. Every non-empty line is parsed as one turn and stored under a newly generated session UUID, so each line becomes its own session.

Returns:

Type Description
dict[str, list[SingleTurn]]

A dictionary mapping session_id to list of CustomData.

Source code in aiperf/dataset/loader/single_turn.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def load_dataset(self) -> dict[str, list[SingleTurn]]:
    """Load single-turn records from a JSONL file.

    Each non-empty line is parsed as one SingleTurn and stored under a
    freshly generated UUID session key.

    Returns:
        A dictionary mapping session_id to a list of SingleTurn objects.
    """
    data: dict[str, list[SingleTurn]] = defaultdict(list)

    with open(self.filename) as f:
        for raw in f:
            stripped = raw.strip()
            if not stripped:
                # Ignore blank lines between records.
                continue
            record = SingleTurn.model_validate_json(stripped)
            data[str(uuid.uuid4())].append(record)

    return data

aiperf.dataset.utils

check_file_exists(filename)

Verifies that the file exists.

Parameters:

Name Type Description Default
filename

The file path to verify.

required

Raises:

Type Description
FileNotFoundError

If the file does not exist.

Source code in aiperf/dataset/utils.py
18
19
20
21
22
23
24
25
26
27
28
def check_file_exists(filename: Path) -> None:
    """Verify that the file exists.

    Args:
        filename: The file path to verify.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    if not filename.exists():
        # Include the offending path in the message so callers can diagnose
        # which file was missing (the placeholder was previously dropped).
        raise FileNotFoundError(f"The file '{filename}' does not exist.")

encode_image(img, format)

Encodes an image into base64 encoded string.

Parameters:

Name Type Description Default
img Image

The PIL Image object to encode.

required
format str

The image format to use (e.g., "JPEG", "PNG").

required

Returns:

Type Description
str

A base64 encoded string representation of the image.

Source code in aiperf/dataset/utils.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def encode_image(img: Image, format: str) -> str:
    """Encode an image into a base64 encoded string.

    Args:
        img: The PIL Image object to encode.
        format: The image format to use (e.g., "JPEG", "PNG").

    Returns:
        A base64 encoded string representation of the image.
    """
    # JPEG cannot store P or RGBA pixels (common for PNG sources), so the
    # image must be converted to RGB before saving in JPEG format.
    if format == "JPEG" and img.mode != "RGB":
        img = img.convert("RGB")

    with BytesIO() as buffer:
        img.save(buffer, format=format)
        raw = buffer.getvalue()
    return base64.b64encode(raw).decode("utf-8")

load_json_str(json_str, func=lambda x: x)

Deserializes JSON encoded string into Python object.

Parameters:

Name Type Description Default
json_str str

JSON encoded string

required
func Callable

A function that takes deserialized JSON object. This can be used to run validation checks on the object. Defaults to identity function.

lambda x: x

Returns:

Type Description
dict[str, Any]

The deserialized JSON object.

Raises:

Type Description
RuntimeError

If the JSON string is invalid.

Source code in aiperf/dataset/utils.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def load_json_str(json_str: str, func: Callable = lambda x: x) -> dict[str, Any]:
    """Deserialize a JSON encoded string into a Python object.

    Args:
        json_str: JSON encoded string
        func: A function applied to the deserialized JSON object, e.g. to run
            validation checks on it. Defaults to the identity function.

    Returns:
        The deserialized JSON object.

    Raises:
        RuntimeError: If the JSON string is invalid.
    """
    try:
        # TODO: Consider using orjson for faster JSON parsing
        return func(json.loads(json_str))
    except json.JSONDecodeError as e:
        # Truncate long payloads so the error message stays readable.
        preview = json_str[:200]
        if len(json_str) > 200:
            preview += "..."
        raise RuntimeError(f"Failed to parse JSON string: '{preview}'") from e

open_image(filename)

Opens an image file.

Parameters:

Name Type Description Default
filename

The file path to open.

required

Returns:

Type Description
Image

The opened PIL Image object.

Raises:

Type Description
FileNotFoundError

If the file does not exist.

Source code in aiperf/dataset/utils.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def open_image(filename: str) -> Image:
    """Open an image file and validate its format.

    Args:
        filename: The file path to open.

    Returns:
        The opened PIL Image object.

    Raises:
        FileNotFoundError: If the file does not exist.
        RuntimeError: If the image format cannot be determined or is not
            one of the supported ImageFormat values.
    """
    check_file_exists(Path(filename))
    img = Image.open(filename)

    if img.format is None:
        # Include the offending path in the message so callers can diagnose
        # which file failed (the placeholder was previously dropped).
        raise RuntimeError(f"Failed to determine image format of '{filename}'.")

    if img.format.upper() not in list(ImageFormat):
        raise RuntimeError(
            f"'{img.format}' is not one of the supported image formats: "
            f"{', '.join(ImageFormat)}"
        )
    return img

sample_normal(mean, stddev, lower=-np.inf, upper=np.inf)

Sample from a normal distribution with support for bounds using rejection sampling.

Parameters:

Name Type Description Default
mean float

The mean of the normal distribution.

required
stddev float

The standard deviation of the normal distribution.

required
lower float

The lower bound of the distribution.

-inf
upper float

The upper bound of the distribution.

inf

Returns:

Type Description
int

An integer sampled from the distribution.

Source code in aiperf/dataset/utils.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def sample_normal(
    mean: float, stddev: float, lower: float = -np.inf, upper: float = np.inf
) -> float:
    """Sample from a normal distribution with support for bounds using rejection sampling.

    Args:
        mean: The mean of the normal distribution.
        stddev: The standard deviation of the normal distribution.
        lower: The lower bound of the distribution.
        upper: The upper bound of the distribution.

    Returns:
        A float sampled from the distribution within [lower, upper].
        (The previous ``-> int`` annotation was incorrect: the value from
        ``np.random.normal`` is returned without integer conversion.)

    NOTE: If [lower, upper] covers almost none of the distribution's mass,
    rejection sampling can iterate for a long time.
    """
    # Rejection sampling: redraw until the sample falls within the bounds.
    while True:
        n = np.random.normal(mean, stddev)
        if lower <= n <= upper:
            return n

sample_positive_normal(mean, stddev)

Sample from a normal distribution ensuring positive values without distorting the distribution.

Parameters:

Name Type Description Default
mean float

Mean value for the normal distribution

required
stddev float

Standard deviation for the normal distribution

required

Returns:

Type Description
float

A positive sample from the normal distribution

Raises:

Type Description
ValueError

If mean is less than 0

Source code in aiperf/dataset/utils.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def sample_positive_normal(mean: float, stddev: float) -> float:
    """Sample from a normal distribution restricted to non-negative values
    without distorting the distribution shape.

    Args:
        mean: Mean value for the normal distribution
        stddev: Standard deviation for the normal distribution

    Returns:
        A positive sample from the normal distribution

    Raises:
        ValueError: If mean is less than 0
    """
    if mean < 0:
        raise ValueError(f"Mean value ({mean}) should be greater than 0")
    # Delegate to the bounded sampler with a zero lower bound; rejection
    # sampling preserves the shape of the truncated distribution.
    sample = sample_normal(mean, stddev, lower=0)
    return sample

sample_positive_normal_integer(mean, stddev)

Sample a random positive integer from a normal distribution.

Parameters:

Name Type Description Default
mean float

The mean of the normal distribution.

required
stddev float

The standard deviation of the normal distribution.

required

Returns:

Type Description
int

A positive integer sampled from the distribution. If the sampled

int

number is less than 1, it returns 1.

Source code in aiperf/dataset/utils.py
138
139
140
141
142
143
144
145
146
147
148
149
def sample_positive_normal_integer(mean: float, stddev: float) -> int:
    """Sample a random positive integer from a normal distribution.

    Args:
        mean: The mean of the normal distribution.
        stddev: The standard deviation of the normal distribution.

    Returns:
        A positive integer sampled from the distribution. Positive samples
        below 1 are rounded up to 1 by the ceiling.
    """
    sampled = sample_positive_normal(mean, stddev)
    # Ceiling maps any sample in (0, 1] to 1 and keeps larger values integral.
    return math.ceil(sampled)

aiperf.exporters.console_error_exporter

ConsoleErrorExporter

A class that exports error data to the console

Source code in aiperf/exporters/console_error_exporter.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@DataExporterFactory.register(DataExporterType.CONSOLE_ERROR)
class ConsoleErrorExporter:
    """Exports a summary of request errors to the console as a table."""

    # (header, justification) pairs for the error summary columns.
    _COLUMNS = (
        ("Code", "right"),
        ("Type", "right"),
        ("Message", "left"),
        ("Count", "right"),
    )

    def __init__(self, exporter_config: ExporterConfig, **kwargs):
        self._results = exporter_config.results

    async def export(self, width: int | None = None) -> None:
        # Nothing to print when the run produced no errors.
        if not self._results.error_summary:
            return

        table = Table(title=self._get_title(), width=width)
        for header, justify in self._COLUMNS:
            table.add_column(header, justify=justify, style="yellow")
        self._construct_table(table, self._results.error_summary)

        console = Console()
        console.print("\n")
        console.print(table)
        console.file.flush()

    def _construct_table(
        self, table: Table, errors_by_type: list[ErrorDetailsCount]
    ) -> None:
        """Append one row per distinct error type."""
        for entry in errors_by_type:
            table.add_row(*self._format_row(entry))

    def _format_row(self, error_details_count: ErrorDetailsCount) -> list[str]:
        """Render one error-count entry as table cell strings."""
        details = error_details_count.error_details
        placeholder = "[dim]N/A[/dim]"
        code = str(details.code) if details.code else placeholder
        kind = str(details.type) if details.type else placeholder
        return [code, kind, str(details.message), f"{error_details_count.count:,}"]

    def _get_title(self) -> str:
        return "[bold][red]NVIDIA AIPerf | Error Summary[/red][/bold]"

aiperf.exporters.console_exporter

ConsoleExporter

Bases: AIPerfLoggerMixin

A class that exports data to the console

Source code in aiperf/exporters/console_exporter.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
@implements_protocol(DataExporterProtocol)
@DataExporterFactory.register(DataExporterType.CONSOLE)
class ConsoleExporter(AIPerfLoggerMixin):
    """Exports benchmark metric results to the console as a rich table."""

    # Statistic columns rendered for every metric row, in display order.
    STAT_COLUMN_KEYS = ["avg", "min", "max", "p99", "p90", "p75", "std"]

    def __init__(self, exporter_config: ExporterConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self._results = exporter_config.results
        self._endpoint_type = exporter_config.user_config.endpoint.type
        # Hidden (internal) metrics are only rendered when this flag is set.
        self._show_internal_metrics = (
            exporter_config.user_config.output.show_internal_metrics
        )

    async def export(self, width: int | None = None) -> None:
        # Print nothing when the run produced no metric records.
        if not self._results.records:
            self.warning("No records to export")
            return

        table = Table(title=self._get_title(), width=width)
        table.add_column("Metric", justify="right", style="cyan")
        for key in self.STAT_COLUMN_KEYS:
            table.add_column(key, justify="right", style="green")
        self._construct_table(table, self._results.records)

        console = Console()
        console.print("\n")
        console.print(table)
        console.file.flush()

    def _construct_table(self, table: Table, records: list[MetricResult]) -> None:
        # Sort by each metric's registered display_order; metrics without one
        # sort last (sys.maxsize).
        records = sorted(
            records,
            key=lambda x: MetricRegistry.get_class(x.tag).display_order or sys.maxsize,
        )
        for record in records:
            if self._should_skip(record):
                continue
            table.add_row(*self._format_row(record))

    def _should_skip(self, record: MetricResult) -> bool:
        # Error-only metrics never appear in this table; hidden metrics appear
        # only when internal metrics are explicitly enabled.
        metric_class = MetricRegistry.get_class(record.tag)
        if metric_class.has_flags(MetricFlags.ERROR_ONLY):
            return True
        return (
            metric_class.has_flags(MetricFlags.HIDDEN)
            and not self._show_internal_metrics
        )

    def _format_row(self, record: MetricResult) -> list[str]:
        metric_class = MetricRegistry.get_class(record.tag)
        display_unit = metric_class.display_unit or metric_class.unit
        row = [f"{record.header} ({display_unit})"]
        for stat in self.STAT_COLUMN_KEYS:
            value = getattr(record, stat, None)
            if value is None:
                row.append("[dim]N/A[/dim]")
                continue

            # Count should never be unit-converted (it's always just the number of records)
            # NOTE(review): "count" is not in STAT_COLUMN_KEYS, so this guard
            # looks defensive — confirm before removing it.
            if display_unit != metric_class.unit and stat != "count":
                try:
                    value = metric_class.unit.convert_to(display_unit, value)
                except MetricUnitError as e:
                    # On conversion failure the raw (unconverted) value is shown.
                    self.warning(f"Error during unit conversion: {e}")

            if isinstance(value, datetime):
                value = value.strftime("%Y-%m-%d %H:%M:%S")
            elif isinstance(value, int | float):
                value = f"{value:,.2f}"
            else:
                value = str(value)
            row.append(value)
        return row

    def _get_title(self) -> str:
        return f"NVIDIA AIPerf | {self._endpoint_type.metrics_title}"

aiperf.exporters.exporter_config

aiperf.exporters.exporter_manager

ExporterManager

Bases: AIPerfLoggerMixin

ExporterManager is responsible for exporting records using all registered data exporters.

Source code in aiperf/exporters/exporter_manager.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class ExporterManager(AIPerfLoggerMixin):
    """
    ExporterManager is responsible for exporting records using all
    registered data exporters.
    """

    def __init__(self, results: ProfileResults, input_config: UserConfig, **kwargs):
        super().__init__(**kwargs)
        self._results = results
        self._input_config = input_config

    async def export_all(self) -> None:
        """Run every registered data exporter concurrently and wait for all.

        Exporter failures are logged rather than raised, so one failing
        exporter does not prevent the others from completing.
        """
        self.info("Exporting all records")
        tasks: set[asyncio.Task] = set()
        exporter_config = ExporterConfig(
            results=self._results,
            user_config=self._input_config,
        )

        def task_done_callback(task: asyncio.Task) -> None:
            self.debug(lambda: f"Task done: {task}")
            # task.exception() raises CancelledError when called on a
            # cancelled task, so check for cancellation explicitly first.
            if task.cancelled():
                self.warning("Export task was cancelled")
            elif task.exception():
                self.error(f"Error exporting records: {task.exception()}")
            else:
                self.debug(f"Exported records: {task.result()}")
            tasks.discard(task)

        for exporter_type in DataExporterFactory.get_all_class_types():
            exporter = DataExporterFactory.create_instance(
                exporter_type, exporter_config=exporter_config
            )
            self.debug(f"Creating task for exporter: {exporter_type}")
            task = asyncio.create_task(exporter.export())
            tasks.add(task)
            task.add_done_callback(task_done_callback)

        # return_exceptions=True: failures were already logged in the callback,
        # so gather must not re-raise them here.
        await asyncio.gather(*tasks, return_exceptions=True)
        self.debug("Exporting all records completed")

aiperf.exporters.json_exporter

JsonExportData

Bases: BaseModel

Data to be exported to a JSON file.

Source code in aiperf/exporters/json_exporter.py
21
22
23
24
25
26
27
28
29
class JsonExportData(BaseModel):
    """Data to be exported to a JSON file."""

    # The user configuration the benchmark was run with.
    input_config: UserConfig | None = None
    # Metric results keyed by metric tag.
    records: dict[MetricTagT, MetricResult] | None = None
    # Whether the profiling run was cancelled before completion.
    was_cancelled: bool | None = None
    # Per-error counts for failed requests.
    error_summary: list[ErrorDetailsCount] | None = None
    # Wall-clock start/end of the profiling run (None if unavailable).
    start_time: datetime | None = None
    end_time: datetime | None = None

JsonExporter

Bases: AIPerfLoggerMixin

A class to export records to a JSON file.

Source code in aiperf/exporters/json_exporter.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
@DataExporterFactory.register(DataExporterType.JSON)
@implements_protocol(DataExporterProtocol)
class JsonExporter(AIPerfLoggerMixin):
    """
    A class to export records to a JSON file.
    """

    def __init__(self, exporter_config: ExporterConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self.debug(lambda: f"Initializing JsonExporter with config: {exporter_config}")
        self._results = exporter_config.results
        self._output_directory = exporter_config.user_config.output.artifact_directory
        self._input_config = exporter_config.user_config

    async def export(self) -> None:
        """Write the profile results to `profile_export_aiperf.json` in the artifact directory."""
        filename = self._output_directory / "profile_export_aiperf.json"
        # Create the artifact directory (and parents) if it does not exist yet.
        self._output_directory.mkdir(parents=True, exist_ok=True)

        # Convert nanosecond timestamps to datetime objects.
        # NOTE(review): fromtimestamp yields naive local-time datetimes —
        # confirm local time (not UTC) is intended here.
        start_time = (
            datetime.fromtimestamp(self._results.start_ns / NANOS_PER_SECOND)
            if self._results.start_ns
            else None
        )
        end_time = (
            datetime.fromtimestamp(self._results.end_ns / NANOS_PER_SECOND)
            if self._results.end_ns
            else None
        )

        export_data = JsonExportData(
            input_config=self._input_config,
            records={record.tag: record for record in self._results.records},
            was_cancelled=self._results.was_cancelled,
            error_summary=self._results.error_summary,
            start_time=start_time,
            end_time=end_time,
        )

        self.debug(lambda: f"Exporting data to JSON file: {export_data}")
        # exclude_unset keeps fields that were never assigned out of the output.
        export_data_json = export_data.model_dump_json(indent=2, exclude_unset=True)
        async with aiofiles.open(filename, "w") as f:
            await f.write(export_data_json)

aiperf.metrics.base_aggregate_metric

BaseAggregateMetric

Bases: Generic[MetricValueTypeVarT], BaseMetric[MetricValueTypeVarT], ABC

A base class for aggregate metrics. These metrics keep track of a value or list of values over time.

This metric type is unique in the fact that it is split into 2 distinct phases of processing, in order to support distributed processing.

For each distributed RecordProcessor, an instance of this class is created. This instance is passed the record and the existing record metrics, and is responsible for returning the individual value for that record. It should not use or update the aggregate value here.

The ResultsProcessor creates a singleton instance of this class, which will be used to aggregate the results from the distributed RecordProcessors. It calls the _aggregate_value method, which each metric class must implement to define how values from different processes are aggregated, such as summing the values, or taking the min/max/average, etc.

Examples:

class RequestCountMetric(BaseAggregateMetric[int]):
    # ... Metric attributes ...

    def _parse_record(self, record: ParsedResponseRecord, record_metrics: MetricRecordDict) -> int:
        # We just return 1 since we are tracking the total count, and this is a single request.
        return 1

    def _aggregate_value(self, value: int) -> None:
        # We add the value to the aggregate value.
        self._value += value
Source code in aiperf/metrics/base_aggregate_metric.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
class BaseAggregateMetric(
    Generic[MetricValueTypeVarT], BaseMetric[MetricValueTypeVarT], ABC
):
    """A base class for aggregate metrics. These metrics keep track of a value or list of values over time.

    This metric type is unique in the fact that it is split into 2 distinct phases of processing, in order to support distributed processing.

    For each distributed RecordProcessor, an instance of this class is created. This instance is passed the record and the existing record metrics,
    and is responsible for returning the individual value for that record. It should not use or update the aggregate value here.

    The ResultsProcessor creates a singleton instance of this class, which will be used to aggregate the results from the distributed
    RecordProcessors. It calls the `_aggregate_value` method, which each metric class must implement to define how values from different
    processes are aggregated, such as summing the values, or taking the min/max/average, etc.

    Examples:
    ```python
    class RequestCountMetric(BaseAggregateMetric[int]):
        # ... Metric attributes ...

        def _parse_record(self, record: ParsedResponseRecord, record_metrics: MetricRecordDict) -> int:
            # We just return 1 since we are tracking the total count, and this is a single request.
            return 1

        def _aggregate_value(self, value: int) -> None:
            # We add the value to the aggregate value.
            self._value += value
    ```
    """

    type = MetricType.AGGREGATE

    def __init__(self, default_value: MetricValueTypeVarT | None = None) -> None:
        """Initialize the metric, optionally with a default value. If no default value is provided,
        the default value is automatically set based on the value type."""
        self._value: MetricValueTypeVarT = (  # type: ignore
            default_value
            if default_value is not None
            else self.value_type.default_factory()
        )
        # Public alias so callers invoke the subclass-defined aggregation hook.
        self.aggregate_value: Callable[[MetricValueTypeVarT], None] = (
            self._aggregate_value
        )
        super().__init__()

    @property
    def current_value(self) -> MetricValueTypeVarT:
        """Get the current value of the metric."""
        return self._value

    def parse_record(
        self, record: ParsedResponseRecord, record_metrics: MetricRecordDict
    ) -> MetricValueTypeVarT:
        """Parse the record and return the individual value.

        Raises:
            ValueError: If the metric cannot be computed for the given inputs.
        """
        # Validate the record and required dependency metrics before delegating
        # to the subclass-specific parsing logic.
        self._require_valid_record(record)
        self._check_metrics(record_metrics)
        return self._parse_record(record, record_metrics)

    @abstractmethod
    def _parse_record(
        self, record: ParsedResponseRecord, record_metrics: MetricRecordDict
    ) -> MetricValueTypeVarT:
        """Parse the record and *return* the individual value base on this record, and this record alone. This
        method is implemented by subclasses.

        NOTE: Do not use or update the aggregate value here.

        This method is called after the required metrics are checked, so it can assume that the required metrics are available.
        This method is called after the record is checked, so it can assume that the record is valid.

        Raises:
            ValueError: If the metric cannot be computed for the given inputs.
        """
        raise NotImplementedError("Subclasses must implement this method")

    # NOTE: This method does not return a value on purpose, as a hint to the user that the
    #       internal value is supposed to be updated.
    @abstractmethod
    def _aggregate_value(self, value: MetricValueTypeVarT) -> None:
        """Aggregate the metric value. This method is implemented by subclasses.

        This method is called with the result value from the `_parse_record` method, from each distributed record processor.
        It is the responsibility of each metric class to implement how values from different processes are aggregated, such
        as summing the values, or taking the min/max/average, etc.

        NOTE: The order of the values is not guaranteed.
        """
        raise NotImplementedError("Subclasses must implement this method")

current_value property

Get the current value of the metric.

__init__(default_value=None)

Initialize the metric, optionally with a default value. If no default value is provided, the default value is automatically set based on the value type.

Source code in aiperf/metrics/base_aggregate_metric.py
44
45
46
47
48
49
50
51
52
53
54
55
def __init__(self, default_value: MetricValueTypeVarT | None = None) -> None:
    """Initialize the metric, optionally with a default value. If no default value is provided,
    the default value is automatically set based on the value type."""
    self._value: MetricValueTypeVarT = (  # type: ignore
        default_value
        if default_value is not None
        else self.value_type.default_factory()
    )
    # Expose the subclass's _aggregate_value implementation as a public
    # callable attribute, so callers can invoke the aggregation hook directly.
    self.aggregate_value: Callable[[MetricValueTypeVarT], None] = (
        self._aggregate_value
    )
    super().__init__()

parse_record(record, record_metrics)

Parse the record and return the individual value.

Raises:

Type Description
ValueError

If the metric cannot be computed for the given inputs.

Source code in aiperf/metrics/base_aggregate_metric.py
62
63
64
65
66
67
68
69
70
71
72
def parse_record(
    self, record: ParsedResponseRecord, record_metrics: MetricRecordDict
) -> MetricValueTypeVarT:
    """Parse the record and return the individual value.

    Validates the record and the required metrics before delegating to the
    subclass-implemented `_parse_record`.

    Raises:
        ValueError: If the record is invalid, a required metric is missing,
            or the metric cannot be computed for the given inputs.
    """
    self._require_valid_record(record)
    self._check_metrics(record_metrics)
    return self._parse_record(record, record_metrics)

aiperf.metrics.base_derived_metric

BaseDerivedMetric

Bases: Generic[MetricValueTypeVarT], BaseMetric[MetricValueTypeVarT], ABC

A base class for derived metrics. These metrics are computed from other metrics, and do not require any knowledge of the individual records. The final results will be a single computed value (or list of values).

NOTE: The generic type can be a list of values, or a single value.

Examples:

class RequestThroughputMetric(BaseDerivedMetric[float]):
    # ... Metric attributes ...

    def _derive_value(self, metric_results: MetricResultsDict) -> float:
        request_count = metric_results[RequestCountMetric.tag]
        benchmark_duration = metric_results[BenchmarkDurationMetric.tag]
        return request_count / (benchmark_duration / NANOS_PER_SECOND)
Source code in aiperf/metrics/base_derived_metric.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class BaseDerivedMetric(
    Generic[MetricValueTypeVarT], BaseMetric[MetricValueTypeVarT], ABC
):
    """A base class for derived metrics. These metrics are computed from other metrics,
    and do not require any knowledge of the individual records. The final results will be a single computed value (or list of values).

    NOTE: The generic type can be a list of values, or a single value.

    Examples:
    ```python
    class RequestThroughputMetric(BaseDerivedMetric[float]):
        # ... Metric attributes ...

        def _derive_value(self, metric_results: MetricResultsDict) -> float:
            request_count = metric_results[RequestCountMetric.tag]
            benchmark_duration = metric_results[BenchmarkDurationMetric.tag]
            return request_count / (benchmark_duration / NANOS_PER_SECOND)
    ```
    """

    # Marks all concrete subclasses as DERIVED metrics.
    type = MetricType.DERIVED

    def derive_value(self, metric_results: MetricResultsDict) -> MetricValueTypeVarT:
        """Derive the metric value.

        Raises:
            ValueError: If a required metric is missing from `metric_results`,
                or the metric cannot be computed for the given inputs.
        """
        self._check_metrics(metric_results)
        return self._derive_value(metric_results)

    @abstractmethod
    def _derive_value(self, metric_results: MetricResultsDict) -> MetricValueTypeVarT:
        """Derive the metric value. This method is implemented by subclasses.
        This method is called after the required metrics are checked, so it can assume that the required metrics are available.

        Raises:
            ValueError: If the metric cannot be computed for the given inputs.
        """
        raise NotImplementedError("Subclasses must implement this method")

derive_value(metric_results)

Derive the metric value.

Source code in aiperf/metrics/base_derived_metric.py
33
34
35
36
def derive_value(self, metric_results: MetricResultsDict) -> MetricValueTypeVarT:
    """Derive the metric value.

    Raises:
        ValueError: If a required metric is missing from `metric_results`.
    """
    self._check_metrics(metric_results)
    return self._derive_value(metric_results)

aiperf.metrics.base_metric

BaseMetric

Bases: Generic[MetricValueTypeVarT], ABC

A definition of a metric type.

This class is not meant to be instantiated directly or subclassed directly. It is meant to be a common base for all of the base metric classes by type.

The class attributes are: - tag: The tag of the metric. This must be a non-empty string that uniquely identifies the metric type. - header: The header of the metric. This is the user-friendly name of the metric that will be displayed in the UI. - unit: The unit of the internal representation of the metric. This is used for converting to other units and for display. - display_unit: The unit of the metric that is used for display (if different from the unit). None means use the unit for display. - display_order: The display order in the ConsoleExporter. Lower numbers are displayed first. None means unordered after any ordered metrics. - flags: The flags of the metric that determine how and when it is computed and displayed. - required_metrics: The metrics that must be available to compute the metric. This is a set of metric tags.

Source code in aiperf/metrics/base_metric.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
class BaseMetric(Generic[MetricValueTypeVarT], ABC):
    """A definition of a metric type.

    This class is not meant to be instantiated directly or subclassed directly.
    It is meant to be a common base for all of the base metric classes by type.

    The class attributes are:
    - tag: The tag of the metric. This must be a non-empty string that uniquely identifies the metric type.
    - header: The header of the metric. This is the user-friendly name of the metric that will be displayed in the UI.
    - unit: The unit of the internal representation of the metric. This is used for converting to other units and for display.
    - display_unit: The unit of the metric that is used for display (if different from the unit). None means use the unit for display.
    - display_order: The display order in the ConsoleExporter. Lower numbers are displayed first. None means unordered after any ordered metrics.
    - flags: The flags of the metric that determine how and when it is computed and displayed.
    - required_metrics: The metrics that must be available to compute the metric. This is a set of metric tags.
    """

    # User-defined attributes to be overridden by subclasses
    tag: ClassVar[MetricTagT]
    header: ClassVar[str] = ""
    unit: ClassVar[MetricUnitT]
    display_unit: ClassVar[MetricUnitT | None] = None
    display_order: ClassVar[int | None] = None
    flags: ClassVar[MetricFlags] = MetricFlags.NONE
    required_metrics: ClassVar[set[MetricTagT] | None] = None

    # Auto-derived attributes
    value_type: ClassVar[MetricValueType]  # Auto set based on generic type parameter
    type: ClassVar[MetricType]  # Set by base subclasses

    def __init_subclass__(cls, **kwargs):
        """
        This method is called when a class is subclassed from Metric.
        It automatically registers the subclass in the MetricRegistry
        dictionary using the `tag` class attribute.
        The `tag` attribute must be a non-empty string that uniquely identifies the
        metric type. Only concrete (non-abstract) classes will be registered.

        Raises:
            TypeError: If a concrete subclass does not define a non-empty string
                `tag`, or is not a subclass of one of the valid metric base classes.
        """

        super().__init_subclass__(**kwargs)

        # Only register concrete classes (not abstract ones)
        if inspect.isabstract(cls):
            return

        # Verify that the class is a valid metric type
        # Make sure to do this after checking for abstractness, so that the imports are available
        cls._verify_base_class()

        # Import MetricRegistry here to avoid circular imports
        from aiperf.metrics.metric_registry import MetricRegistry

        # Enforce that subclasses define a non-empty tag
        if not cls.tag or not isinstance(cls.tag, str):
            raise TypeError(
                f"Concrete metric class {cls.__name__} must define a non-empty 'tag' class attribute"
            )

        # Auto-detect value type from generic parameter
        cls.value_type = cls._detect_value_type()

        MetricRegistry.register_metric(cls)

    @classmethod
    def _verify_base_class(cls) -> None:
        """Verify that the class is a subclass of BaseRecordMetric, BaseAggregateMetric, or BaseDerivedMetric.
        This is done to ensure that the class is a valid metric type.

        Raises:
            TypeError: If the class is not a subclass of any of the valid base classes.
        """
        # Note: this is valid because the below imports are abstract, so they will not get here
        from aiperf.metrics import (
            BaseAggregateMetric,
            BaseDerivedMetric,
            BaseRecordMetric,
        )

        # Enforce that concrete subclasses are a subclass of BaseRecordMetric, BaseAggregateMetric, or BaseDerivedMetric
        valid_base_classes = {
            BaseRecordMetric,
            BaseAggregateMetric,
            BaseDerivedMetric,
        }
        if not any(issubclass(cls, base) for base in valid_base_classes):
            raise TypeError(
                f"Concrete metric class {cls.__name__} must be a subclass of BaseRecordMetric, BaseAggregateMetric, or BaseDerivedMetric"
            )

    @classmethod
    def _detect_value_type(cls) -> MetricValueType:
        """Automatically detect the MetricValueType from the generic type parameter.

        Raises:
            ValueError: If no parameterized generic base can be found on the class.
        """
        # Look through the class hierarchy for the first Generic[Type] definition
        for base in cls.__orig_bases__:  # type: ignore
            if get_origin(base) is not None:
                args = get_args(base)
                if args:
                    # the first argument is the generic type
                    generic_type = args[0]
                    return MetricValueType.from_python_type(generic_type)

        raise ValueError(
            f"Unable to detect the value type for {cls.__name__}. Please check the generic type parameter."
        )

    def _require_valid_record(self, record: ParsedResponseRecord) -> None:
        """Check that the record is valid.

        Metrics carrying the ERROR_ONLY flag skip this check — presumably
        because they are expected to process errored records; confirm against
        the flag's definition.

        Raises:
            ValueError: If the record is missing or invalid (and the metric is not ERROR_ONLY).
        """
        if (not record or not record.valid) and not self.has_flags(
            MetricFlags.ERROR_ONLY
        ):
            raise ValueError("Invalid Record")

    def _check_metrics(self, metrics: MetricRecordDict | MetricResultsDict) -> None:
        """Check that the required metrics are available.

        Raises:
            ValueError: If any tag listed in `required_metrics` is absent from `metrics`.
        """
        if self.required_metrics is None:
            return
        for tag in self.required_metrics:
            if tag not in metrics:
                raise ValueError(f"Missing required metric: '{tag}'")

    @classmethod
    def has_flags(cls, flags: MetricFlags) -> bool:
        """Return True if the metric has the given flag(s) (regardless of other flags)."""
        return cls.flags.has_flags(flags)

    @classmethod
    def missing_flags(cls, flags: MetricFlags) -> bool:
        """Return True if the metric does not have the given flag(s) (regardless of other flags). It will
        return False if the metric has ANY of the given flags."""
        return cls.flags.missing_flags(flags)

__init_subclass__(**kwargs)

This method is called when a class is subclassed from Metric. It automatically registers the subclass in the MetricRegistry dictionary using the tag class attribute. The tag attribute must be a non-empty string that uniquely identifies the metric type. Only concrete (non-abstract) classes will be registered.

Source code in aiperf/metrics/base_metric.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def __init_subclass__(cls, **kwargs):
    """
    This method is called when a class is subclassed from Metric.
    It automatically registers the subclass in the MetricRegistry
    dictionary using the `tag` class attribute.
    The `tag` attribute must be a non-empty string that uniquely identifies the
    metric type. Only concrete (non-abstract) classes will be registered.

    Raises:
        TypeError: If a concrete subclass does not define a non-empty string
            `tag`, or is not a subclass of one of the valid metric base classes.
    """

    super().__init_subclass__(**kwargs)

    # Only register concrete classes (not abstract ones)
    if inspect.isabstract(cls):
        return

    # Verify that the class is a valid metric type
    # Make sure to do this after checking for abstractness, so that the imports are available
    cls._verify_base_class()

    # Import MetricRegistry here to avoid circular imports
    from aiperf.metrics.metric_registry import MetricRegistry

    # Enforce that subclasses define a non-empty tag
    if not cls.tag or not isinstance(cls.tag, str):
        raise TypeError(
            f"Concrete metric class {cls.__name__} must define a non-empty 'tag' class attribute"
        )

    # Auto-detect value type from generic parameter
    cls.value_type = cls._detect_value_type()

    MetricRegistry.register_metric(cls)

has_flags(flags) classmethod

Return True if the metric has the given flag(s) (regardless of other flags).

Source code in aiperf/metrics/base_metric.py
139
140
141
142
@classmethod
def has_flags(cls, flags: MetricFlags) -> bool:
    """Return True if the metric has the given flag(s) (regardless of other flags).

    Args:
        flags: The flag(s) to test this metric's `flags` attribute against.
    """
    return cls.flags.has_flags(flags)

missing_flags(flags) classmethod

Return True if the metric does not have the given flag(s) (regardless of other flags). It will return False if the metric has ANY of the given flags.

Source code in aiperf/metrics/base_metric.py
144
145
146
147
148
@classmethod
def missing_flags(cls, flags: MetricFlags) -> bool:
    """Return True if the metric does not have the given flag(s) (regardless of other flags). It will
    return False if the metric has ANY of the given flags.

    Args:
        flags: The flag(s) to test this metric's `flags` attribute against.
    """
    return cls.flags.missing_flags(flags)

aiperf.metrics.base_record_metric

BaseRecordMetric

Bases: Generic[MetricValueTypeVarT], BaseMetric[MetricValueTypeVarT], ABC

A base class for record-based metrics. These metrics are computed for each record, and are independent of other records. The final results will be a list of values, one for each record.

NOTE: Set the generic type to be the type of the individual values, and NOT a list, unless the metric produces a list for every record. In that case, the result will be a list of lists.

Examples:

class InputSequenceLengthMetric(BaseRecordMetric[int]):
    # ... Metric attributes ...
    # ... Input validation ...

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> int:
        return record.input_token_count
Source code in aiperf/metrics/base_record_metric.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class BaseRecordMetric(
    Generic[MetricValueTypeVarT], BaseMetric[MetricValueTypeVarT], ABC
):
    """A base class for record-based metrics. These metrics are computed for each record,
    and are independent of other records. The final results will be a list of values, one for each record.

    NOTE: Set the generic type to be the type of the individual values, and NOT a list, unless the metric produces
    a list *for every record*. In that case, the result will be a list of lists.

    Examples:
    ```python
    class InputSequenceLengthMetric(BaseRecordMetric[int]):
        # ... Metric attributes ...
        # ... Input validation ...

        def _parse_record(
            self,
            record: ParsedResponseRecord,
            record_metrics: MetricRecordDict,
        ) -> int:
            return record.input_token_count
    ```
    """

    # Marks all concrete subclasses as RECORD metrics.
    type = MetricType.RECORD

    def parse_record(
        self, record: ParsedResponseRecord, record_metrics: MetricRecordDict
    ) -> MetricValueTypeVarT:
        """Parse a single record and return the metric value.

        Raises:
            ValueError: If the record is invalid, a required metric is missing,
                or the metric cannot be computed for the given inputs.
        """
        self._require_valid_record(record)
        self._check_metrics(record_metrics)
        return self._parse_record(record, record_metrics)

    @abstractmethod
    def _parse_record(
        self, record: ParsedResponseRecord, record_metrics: MetricRecordDict
    ) -> MetricValueTypeVarT:
        """Parse a single record and return the metric value. This method is implemented by subclasses.
        This method is called after the required metrics are checked, so it can assume that the required metrics are available.
        This method is called after the record is checked, so it can assume that the record is valid.

        Raises:
            ValueError: If the metric cannot be computed for the given inputs.
        """
        raise NotImplementedError("Subclasses must implement this method")

parse_record(record, record_metrics)

Parse a single record and return the metric value.

Source code in aiperf/metrics/base_record_metric.py
38
39
40
41
42
43
44
def parse_record(
    self, record: ParsedResponseRecord, record_metrics: MetricRecordDict
) -> MetricValueTypeVarT:
    """Parse a single record and return the metric value.

    Raises:
        ValueError: If the record is invalid, a required metric is missing,
            or the metric cannot be computed for the given inputs.
    """
    self._require_valid_record(record)
    self._check_metrics(record_metrics)
    return self._parse_record(record, record_metrics)

aiperf.metrics.metric_dicts

MetricRecordDict

Bases: dict[MetricTagT, MetricValueTypeT]

A dict of metrics for a single record. This is used to store the current values of all metrics that have been computed for a single record.

This will include: - The current value of any BaseRecordMetric that has been computed for this record. - The new value of any BaseAggregateMetric that has been computed for this record. - No BaseDerivedMetrics will be included.

Source code in aiperf/metrics/metric_dicts.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class MetricRecordDict(dict[MetricTagT, MetricValueTypeT]):
    """Per-record metric values, keyed by metric tag.

    Holds the values that have been computed so far for one record:
    - The current value of each `BaseRecordMetric` computed for this record.
    - The new value of each `BaseAggregateMetric` computed for this record.
    - No `BaseDerivedMetric`s are ever included.
    """

    def get_converted(
        self, metric: type["BaseMetric"], other_unit: MetricUnitT
    ) -> float:
        """Look up a metric's value and convert it to `other_unit`."""
        raw_value = self[metric.tag]
        return metric.unit.convert_to(other_unit, raw_value)  # type: ignore

get_converted(metric, other_unit)

Get the value of a metric, but converted to a different unit.

Source code in aiperf/metrics/metric_dicts.py
28
29
30
31
32
def get_converted(
    self, metric: type["BaseMetric"], other_unit: MetricUnitT
) -> float:
    """Get the value of a metric, but converted to a different unit.

    Raises:
        KeyError: If `metric.tag` has no entry in this dict.
    """
    return metric.unit.convert_to(other_unit, self[metric.tag])  # type: ignore

MetricResultsDict

Bases: dict[MetricTagT, MetricDictValueTypeT]

A dict of metrics over an entire run. This is used to store the final values of all metrics that have been computed for an entire run.

This will include: - All BaseRecordMetrics as a deque of their values. - The most recent value of each BaseAggregateMetric. - The value of any BaseDerivedMetric that has already been computed.

Source code in aiperf/metrics/metric_dicts.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class MetricResultsDict(dict[MetricTagT, MetricDictValueTypeT]):
    """Run-level metric results, keyed by metric tag.

    Stores the final values of every metric computed over an entire run:
    - Each `BaseRecordMetric` as a deque of its per-record values.
    - The most recent value of each `BaseAggregateMetric`.
    - Any `BaseDerivedMetric` value that has already been computed.
    """

    def get_converted(
        self, metric: type["BaseMetric"], other_unit: MetricUnitT
    ) -> float:
        """Look up a metric's final value and convert it to `other_unit`."""
        if metric.type == MetricType.RECORD:
            # Record metrics are stored as a deque of per-record values, which
            # cannot be unit-converted as a single scalar.
            raise ValueError(
                f"Cannot convert a record metric to a different unit: {metric.tag}"
            )
        raw_value = self[metric.tag]
        return metric.unit.convert_to(other_unit, raw_value)  # type: ignore

get_converted(metric, other_unit)

Get the value of a metric, but converted to a different unit.

Source code in aiperf/metrics/metric_dicts.py
46
47
48
49
50
51
52
53
54
55
def get_converted(
    self, metric: type["BaseMetric"], other_unit: MetricUnitT
) -> float:
    """Get the value of a metric, but converted to a different unit.

    Raises:
        ValueError: If `metric` is a RECORD metric (stored as a deque of values,
            not a single scalar).
        KeyError: If `metric.tag` has no entry in this dict.
    """
    if metric.type == MetricType.RECORD:
        # Record metrics are a deque of values, so we can't convert them directly.
        raise ValueError(
            f"Cannot convert a record metric to a different unit: {metric.tag}"
        )
    return metric.unit.convert_to(other_unit, self[metric.tag])  # type: ignore

aiperf.metrics.metric_registry

MetricRegistry

A registry for metrics.

This is used to store all the metrics that are available to the system. It is used to lookup metrics by their tag, and to get all the metrics that are available. It also provides methods to get metrics by their type, flag, and to create a dependency order for the metrics. It is also used to create instances of metrics.

This class is not meant to be instantiated directly. It is meant to be used as a singleton via classmethods.

Source code in aiperf/metrics/metric_registry.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
class MetricRegistry:
    """
    A registry for metrics.

    This is used to store all the metrics that are available to the system.
    It is used to lookup metrics by their tag, and to get all the metrics that are available.
    It also provides methods to get metrics by their type, flag, and to create a dependency order for the metrics.
    It is also used to create instances of metrics.

    This class is not meant to be instantiated directly. It is meant to be used as a singleton via classmethods.
    """

    # Map of metric tags to their classes
    _metrics_map: dict[MetricTagT, type["BaseMetric"]] = {}

    # Map of metric tags to their instances
    _instances_map: dict[MetricTagT, "BaseMetric"] = {}
    # Guards lazy creation of entries in _instances_map across threads.
    _instance_lock = Lock()

    @classmethod
    def _discover_metrics(cls) -> None:
        """
        This method dynamically imports all metric type modules from the 'types' directory to ensure
        all metric classes are registered via __init_subclass__. This will be called once when the
        module is imported.

        Raises:
            MetricTypeError: If the 'types' directory is missing, or any metric
                type module fails to import.
        """
        # Get the types directory path
        types_dir = Path(__file__).parent / "types"

        # Ensure that the types directory exists
        if not types_dir.exists() or not types_dir.is_dir():
            raise MetricTypeError(
                f"Types directory '{types_dir.resolve()}' does not exist or is not a directory"
            )

        # Get the module prefix for the types directory, which is the parent of
        # this module, plus the types directory name.
        # For example, `aiperf.metrics.metric_registry` will become `aiperf.metrics.types`
        module_prefix = ".".join([*cls.__module__.split(".")[:-1], "types"])
        _logger.debug(
            f"Importing metric type modules from '{types_dir.resolve()}' with module prefix '{module_prefix}'"
        )
        # Import all metric type modules to trigger registration
        cls._import_metric_type_modules(types_dir, module_prefix)

    @classmethod
    def _import_metric_type_modules(cls, types_dir: Path, module_prefix: str) -> None:
        """Import all metric type modules from the given directory. This will raise an error if the module cannot be imported."""
        for python_file in types_dir.glob("*.py"):
            if python_file.name != "__init__.py":
                module_name = python_file.stem  # Get filename without extension
                module_path = f"{module_prefix}.{module_name}"
                try:
                    _logger.debug(
                        f"Importing metric type module: '{module_path}' from '{python_file.resolve()}'"
                    )
                    # Importing the module triggers BaseMetric.__init_subclass__
                    # for every metric class it defines.
                    importlib.import_module(module_path)
                except ImportError as err:
                    raise MetricTypeError(
                        f"Error importing metric type module '{module_path}' from '{python_file.resolve()}'"
                    ) from err

    @classmethod
    def register_metric(cls, metric: type["BaseMetric"]):
        """Register a metric class with the registry. This will raise a MetricTypeError if the class is already registered.

        This method is called automatically via the __init_subclass__ method of the BaseMetric class, so there is no need
        to call it manually.
        """
        if metric.tag in cls._metrics_map:
            # TODO: Should we consider adding an override_priority parameter to the metric class similar to AIPerfFactory?
            #       This would allow the user to override built-in metrics with custom implementations, without requiring
            #       them to modify the built-in metric classes.
            raise MetricTypeError(
                f"Metric class with tag {metric.tag} already registered by {cls._metrics_map[metric.tag].__name__}"
            )

        cls._metrics_map[metric.tag] = metric

    @classmethod
    def get_class(cls, tag: MetricTagT) -> type["BaseMetric"]:
        """Get a metric class by its tag.

        Raises:
            MetricTypeError: If the metric class is not found.
        """
        try:
            return cls._metrics_map[tag]
        except KeyError as e:
            raise MetricTypeError(f"Metric class with tag '{tag}' not found") from e

    @classmethod
    def get_instance(cls, tag: MetricTagT) -> "BaseMetric":
        """Get an instance of a metric class by its tag. This will create a new instance if it does not exist.

        Uses double-checked locking: the unlocked membership test keeps the hot
        path cheap, and the re-check under the lock prevents duplicate creation.

        Raises:
            MetricTypeError: If the metric class is not found.
        """
        # Check first without acquiring the lock for performance reasons. Since this is a hot path, we want to avoid
        # acquiring the lock if we can. We can do this because we have added a secondary check after acquiring the lock.
        if tag not in cls._instances_map:
            with cls._instance_lock:
                # Check again after acquiring the lock
                if tag not in cls._instances_map:
                    metric_class = cls.get_class(tag)
                    cls._instances_map[tag] = metric_class()
        return cls._instances_map[tag]

    @classmethod
    def tags_applicable_to(
        cls,
        required_flags: MetricFlags,
        disallowed_flags: MetricFlags,
        *types: MetricType,
    ) -> list[MetricTagT]:
        """Get metrics tags that are applicable to the given arguments.

        This method is used to filter the metrics that are applicable to a given set of flags and types.
        For instance, this can be used to only get all DERIVED metrics, or only get metrics that are
        applicable to non-streaming endpoints, etc.

        Arguments:
            required_flags: The flags that the metric must have.
            disallowed_flags: The flags that the metric must not have.
            types: The types of metrics to include. If not provided, all types will be included.

        Returns:
            A list of metric tags that are applicable to the given arguments.
        """
        return [
            tag
            for tag, metric_class in cls._metrics_map.items()
            if metric_class.has_flags(required_flags)
            and metric_class.missing_flags(disallowed_flags)
            and (not types or metric_class.type in types)
        ]

    @classmethod
    def all_tags(cls) -> list[MetricTagT]:
        """Get all of the tags of the defined metric classes."""
        return list(cls._metrics_map.keys())

    @classmethod
    def all_classes(cls) -> list[type["BaseMetric"]]:
        """Get all of the classes of the defined metric classes."""
        return list(cls._metrics_map.values())

    @classmethod
    def classes_for(cls, tags: Iterable[MetricTagT]) -> list[type["BaseMetric"]]:
        """Get the classes for the given tags.

        Raises:
            MetricTypeError: If a tag is not found.
        """
        return [cls.get_class(tag) for tag in tags]

    @classmethod
    def _validate_dependencies(cls) -> None:
        """Validate that all dependencies are registered.

        Also enforces which metric types may depend on which (see the table below).

        Raises:
            MetricTypeError: If a dependency is not registered, or is of a type
                the dependent metric is not allowed to depend on.
        """
        all_tags = cls._metrics_map.keys()
        all_classes = cls._metrics_map.values()

        # Map of metric types to the types of metrics they can have dependencies on
        _allowed_dependencies_by_type = {
            # Record metrics can only depend on other record metrics
            MetricType.RECORD: {MetricType.RECORD},
            # Aggregate metrics can depend on other record or aggregate metrics
            MetricType.AGGREGATE: {MetricType.RECORD, MetricType.AGGREGATE},
            # Derived metrics can depend on any other metric type
            MetricType.DERIVED: {
                MetricType.RECORD,
                MetricType.AGGREGATE,
                MetricType.DERIVED,
            },
        }

        # Validate that all required metrics are registered, and that the dependencies are allowed
        for metric in all_classes:
            for required_tag in metric.required_metrics or set():
                # Validate that the dependency is registered
                if required_tag not in all_tags:
                    raise MetricTypeError(
                        f"Metric '{metric.tag}' depends on '{required_tag}', which is not registered"
                    )

                # Validate that the dependency is allowed
                required_metric_type = cls._metrics_map[required_tag].type
                if (
                    required_metric_type
                    not in _allowed_dependencies_by_type[metric.type]
                ):
                    raise MetricTypeError(
                        f"Metric '{metric.tag}' is a {metric.type} metric, but depends on '{required_tag}', which is a {required_metric_type} metric"
                    )

    @classmethod
    def create_dependency_order(cls) -> list[MetricTagT]:
        """
        Create a dependency order for all available metrics using topological sort.

        See :meth:`create_dependency_order_for` for more details.
        """
        return cls.create_dependency_order_for()

    @classmethod
    def create_dependency_order_for(
        cls, tags: Iterable[MetricTagT] | None = None
    ) -> list[MetricTagT]:
        """
        Create a dependency order for the given metrics using topological sort.

        This ensures that all dependencies are computed before their dependents.
        Note that this will only sort and return the tags that were requested. If a tag
        has a dependency that is not in the list of tags, it will not be included in the order.
        This is useful for cases where we want to sort a subset of metrics that have dependencies
        on other metrics that we know are already computed such as is the case for derived metrics
        that are always computed after all the other metrics.

        Returns:
            List of metric tags in dependency order (dependencies first). Will only
            include tags that were in the requested list.

        Raises:
            MetricTypeError: If there are unregistered dependencies or circular dependencies.
        """
        if tags is None:
            tags = cls._metrics_map.keys()

        # Build the dependency graph
        sorter = graphlib.TopologicalSorter()

        for metric in cls.classes_for(tags):
            # Add the metric with its required dependencies
            sorter.add(metric.tag, *(metric.required_metrics or set()))

        try:
            # Get the dependency order
            order = list(sorter.static_order())

            # Make sure we only return the tags that were requested
            tags_set = set(tags)
            return [tag for tag in order if tag in tags_set]
        except graphlib.CycleError as e:
            raise MetricTypeError(
                f"Circular dependency detected among metrics: {e}"
            ) from e

all_classes() classmethod

Get all of the classes of the defined metric classes.

Source code in aiperf/metrics/metric_registry.py
169
170
171
172
@classmethod
def all_classes(cls) -> list[type["BaseMetric"]]:
    """Return every registered metric class as a new list."""
    return [*cls._metrics_map.values()]

all_tags() classmethod

Get all of the tags of the defined metric classes.

Source code in aiperf/metrics/metric_registry.py
164
165
166
167
@classmethod
def all_tags(cls) -> list[MetricTagT]:
    """Return the tag of every registered metric class as a new list."""
    # Iterating the registry mapping yields its keys (the tags).
    return [*cls._metrics_map]

classes_for(tags) classmethod

Get the classes for the given tags.

Raises:

Type Description
MetricTypeError

If a tag is not found.

Source code in aiperf/metrics/metric_registry.py
174
175
176
177
178
179
180
181
@classmethod
def classes_for(cls, tags: Iterable[MetricTagT]) -> list[type["BaseMetric"]]:
    """Resolve each requested tag to its registered metric class.

    Raises:
        MetricTypeError: If a tag is not found.
    """
    return [*map(cls.get_class, tags)]

create_dependency_order() classmethod

Create a dependency order for all available metrics using topological sort.

See :meth:create_dependency_order_for for more details.

Source code in aiperf/metrics/metric_registry.py
226
227
228
229
230
231
232
233
@classmethod
def create_dependency_order(cls) -> list[MetricTagT]:
    """Topologically sort all registered metrics by their dependencies.

    See :meth:`create_dependency_order_for` for more details.
    """
    # Passing no tags sorts the entire registry.
    return cls.create_dependency_order_for(tags=None)

create_dependency_order_for(tags=None) classmethod

Create a dependency order for the given metrics using topological sort.

This ensures that all dependencies are computed before their dependents. Note that this will only sort and return the tags that were requested. If a tag has a dependency that is not in the list of tags, it will not be included in the order. This is useful for cases where we want to sort a subset of metrics that have dependencies on other metrics that we know are already computed such as is the case for derived metrics that are always computed after all the other metrics.

Returns:

Type Description
list[MetricTagT]

List of metric tags in dependency order (dependencies first). Will only

list[MetricTagT]

include tags that were in the requested list.

Raises:

Type Description
MetricTypeError

If there are unregistered dependencies or circular dependencies.

Source code in aiperf/metrics/metric_registry.py
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
@classmethod
def create_dependency_order_for(
    cls, tags: Iterable[MetricTagT] | None = None
) -> list[MetricTagT]:
    """
    Topologically sort the requested metric tags so dependencies come first.

    Every dependency is ordered before its dependents. Only the requested tags
    appear in the result: a dependency that is not itself in ``tags`` is used
    for ordering but dropped from the output. That makes it possible to sort a
    subset of metrics whose remaining dependencies are known to be computed
    already (e.g. derived metrics, which always run after all other metrics).

    Returns:
        List of metric tags in dependency order (dependencies first). Will only
        include tags that were in the requested list.

    Raises:
        MetricTypeError: If there are unregistered dependencies or circular dependencies.
    """
    requested = cls._metrics_map.keys() if tags is None else tags

    # Assemble the dependency graph; out-of-subset dependencies become
    # plain nodes that are filtered out after sorting.
    graph = graphlib.TopologicalSorter()
    for metric_class in cls.classes_for(requested):
        graph.add(metric_class.tag, *(metric_class.required_metrics or set()))

    try:
        # Materialize inside the try: CycleError is raised while iterating.
        ordered = list(graph.static_order())
        requested_set = set(requested)
        return [tag for tag in ordered if tag in requested_set]
    except graphlib.CycleError as e:
        raise MetricTypeError(
            f"Circular dependency detected among metrics: {e}"
        ) from e

get_class(tag) classmethod

Get a metric class by its tag.

Raises:

Type Description
MetricTypeError

If the metric class is not found.

Source code in aiperf/metrics/metric_registry.py
106
107
108
109
110
111
112
113
114
115
116
@classmethod
def get_class(cls, tag: MetricTagT) -> type["BaseMetric"]:
    """Look up the metric class registered under the given tag.

    Raises:
        MetricTypeError: If the metric class is not found.
    """
    try:
        return cls._metrics_map[tag]
    except KeyError as lookup_error:
        # Translate the raw mapping failure into the registry's error type.
        raise MetricTypeError(
            f"Metric class with tag '{tag}' not found"
        ) from lookup_error

get_instance(tag) classmethod

Get an instance of a metric class by its tag. This will create a new instance if it does not exist.

Raises:

Type Description
MetricTypeError

If the metric class is not found.

Source code in aiperf/metrics/metric_registry.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
@classmethod
def get_instance(cls, tag: MetricTagT) -> "BaseMetric":
    """Get an instance of a metric class by its tag. This will create a new instance if it does not exist.

    Instances are cached in ``cls._instances_map``, so repeated calls with the
    same tag return the same object.

    Raises:
        MetricTypeError: If the metric class is not found.
    """
    # Check first without acquiring the lock for performance reasons. Since this is a hot path, we want to avoid
    # acquiring the lock if we can. We can do this because we have added a secondary check after acquiring the lock.
    if tag not in cls._instances_map:
        with cls._instance_lock:
            # Check again after acquiring the lock (double-checked locking):
            # another thread may have created the instance while we waited.
            if tag not in cls._instances_map:
                metric_class = cls.get_class(tag)
                cls._instances_map[tag] = metric_class()
    return cls._instances_map[tag]

register_metric(metric) classmethod

Register a metric class with the registry. This will raise a MetricTypeError if the class is already registered.

This method is called automatically via the `__init_subclass__` method of the BaseMetric class, so there is no need to call it manually.

Source code in aiperf/metrics/metric_registry.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
@classmethod
def register_metric(cls, metric: type["BaseMetric"]):
    """Register a metric class with the registry. This will raise a MetricTypeError if the class is already registered.

    Called automatically from the `__init_subclass__` hook of the BaseMetric
    class, so there is no need to call it manually.
    """
    existing = cls._metrics_map.get(metric.tag)
    if existing is not None:
        # TODO: Should we consider adding an override_priority parameter to the metric class similar to AIPerfFactory?
        #       This would allow the user to override built-in metrics with custom implementations, without requiring
        #       them to modify the built-in metric classes.
        raise MetricTypeError(
            f"Metric class with tag {metric.tag} already registered by {existing.__name__}"
        )

    cls._metrics_map[metric.tag] = metric

tags_applicable_to(required_flags, disallowed_flags, *types) classmethod

Get metrics tags that are applicable to the given arguments.

This method is used to filter the metrics that are applicable to a given set of flags and types. For instance, this can be used to only get all DERIVED metrics, or only get metrics that are applicable to non-streaming endpoints, etc.

Parameters:

Name Type Description Default
required_flags MetricFlags

The flags that the metric must have.

required
disallowed_flags MetricFlags

The flags that the metric must not have.

required
types MetricType

The types of metrics to include. If not provided, all types will be included.

()

Returns:

Type Description
list[MetricTagT]

A list of metric tags that are applicable to the given arguments.

Source code in aiperf/metrics/metric_registry.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
@classmethod
def tags_applicable_to(
    cls,
    required_flags: MetricFlags,
    disallowed_flags: MetricFlags,
    *types: MetricType,
) -> list[MetricTagT]:
    """Get metrics tags that are applicable to the given arguments.

    Filters the registry down to metrics matching a set of flags and types —
    e.g. only DERIVED metrics, or only metrics valid for non-streaming
    endpoints.

    Arguments:
        required_flags: The flags that the metric must have.
        disallowed_flags: The flags that the metric must not have.
        types: The types of metrics to include. If not provided, all types will be included.

    Returns:
        A list of metric tags that are applicable to the given arguments.
    """
    applicable: list[MetricTagT] = []
    for tag, metric_class in cls._metrics_map.items():
        if not metric_class.has_flags(required_flags):
            continue
        if not metric_class.missing_flags(disallowed_flags):
            continue
        # An empty `types` tuple means "accept every metric type".
        if types and metric_class.type not in types:
            continue
        applicable.append(tag)
    return applicable

aiperf.metrics.types.benchmark_duration_metric

BenchmarkDurationMetric

Bases: BaseDerivedMetric[int]

Post-processor for calculating the Benchmark Duration metric.

Formula

Benchmark Duration = Maximum Response Timestamp - Minimum Request Timestamp

Source code in aiperf/metrics/types/benchmark_duration_metric.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class BenchmarkDurationMetric(BaseDerivedMetric[int]):
    """
    Derived metric for the total wall-clock duration of the benchmark.

    Formula:
        Benchmark Duration = Maximum Response Timestamp - Minimum Request Timestamp
    """

    tag = "benchmark_duration"
    header = "Benchmark Duration"
    unit = MetricTimeUnit.NANOSECONDS
    display_unit = MetricTimeUnit.SECONDS
    flags = MetricFlags.HIDDEN
    required_metrics = {
        MinRequestTimestampMetric.tag,
        MaxResponseTimestampMetric.tag,
    }

    def _derive_value(
        self,
        metric_results: MetricResultsDict,
    ) -> int:
        """Subtract the earliest request time from the latest response time.

        Raises:
            ValueError: If either timestamp is missing, or they are not strictly ordered.
        """
        earliest_request = metric_results[MinRequestTimestampMetric.tag]
        latest_response = metric_results[MaxResponseTimestampMetric.tag]

        if earliest_request is None or latest_response is None:
            raise ValueError(
                "Min request and max response are required to calculate benchmark duration."
            )

        if earliest_request >= latest_response:  # type: ignore
            raise ValueError(
                "Min request must be less than max response to calculate benchmark duration."
            )

        return latest_response - earliest_request  # type: ignore

aiperf.metrics.types.benchmark_token_count

BenchmarkTokenCountMetric

Bases: BaseAggregateMetric[int]

Post-processor for calculating the Benchmark Token Count metric. This is the total number of tokens processed by the benchmark.

Formula

Benchmark Token Count = Sum of Output Sequence Lengths

Source code in aiperf/metrics/types/benchmark_token_count.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class BenchmarkTokenCountMetric(BaseAggregateMetric[int]):
    """
    Aggregate metric summing the total number of tokens processed by the benchmark.

    Formula:
        Benchmark Token Count = Sum of Output Sequence Lengths
    """

    tag = "benchmark_token_count"
    header = "Benchmark Token Count"
    unit = GenericMetricUnit.TOKENS
    flags = (
        MetricFlags.PRODUCES_TOKENS_ONLY
        | MetricFlags.LARGER_IS_BETTER
        | MetricFlags.HIDDEN
    )
    required_metrics = {
        OutputSequenceLengthMetric.tag,
    }

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> int:
        # Each record contributes its output sequence length; the running
        # total is maintained by the ResultsProcessor via `_aggregate_value`.
        return record_metrics[OutputSequenceLengthMetric.tag]  # type: ignore

    def _aggregate_value(self, value: int) -> None:
        """Sum the per-process token counts into the running total."""
        self._value += value

aiperf.metrics.types.error_request_count

ErrorRequestCountMetric

Bases: BaseAggregateMetric[int]

Post-processor for counting the number of error requests.

This metric is only applicable to error records.

Formula

Error Request Count = Sum(Error Requests)

Source code in aiperf/metrics/types/error_request_count.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class ErrorRequestCountMetric(BaseAggregateMetric[int]):
    """
    Aggregate metric tallying how many error requests were observed.

    This metric is only applicable to error records.

    Formula:
        Error Request Count = Sum(Error Requests)
    """

    tag = "error_request_count"
    header = "Error Request Count"
    unit = GenericMetricUnit.REQUESTS
    flags = MetricFlags.ERROR_ONLY
    required_metrics = None

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> int:
        # Every record routed here is an error record, so each one counts as 1.
        return 1

    def _aggregate_value(self, value: int) -> None:
        """Sum the per-process error counts into the running total."""
        self._value += value

aiperf.metrics.types.input_sequence_length_metric

InputSequenceLengthMetric

Bases: BaseRecordMetric[int]

Post-processor for calculating Input Sequence Length (ISL) metrics from records.

Formula

Input Sequence Length = Sum of Input Token Counts

Source code in aiperf/metrics/types/input_sequence_length_metric.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class InputSequenceLengthMetric(BaseRecordMetric[int]):
    """
    Record metric reporting the Input Sequence Length (ISL) of each record.

    Formula:
        Input Sequence Length = Sum of Input Token Counts
    """

    tag = "input_sequence_length"
    header = "Input Sequence Length"
    unit = GenericMetricUnit.TOKENS
    display_order = 700
    flags = MetricFlags.PRODUCES_TOKENS_ONLY | MetricFlags.LARGER_IS_BETTER
    required_metrics = None

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> int:
        """Return the record's input token count.

        Raises:
            ValueError: If the record does not have an input token count.
        """
        token_count = record.input_token_count
        if token_count is None:
            raise ValueError("Input Token Count is not available for the record.")
        return token_count

aiperf.metrics.types.input_throughput

InputThroughputMetric

Bases: BaseRecordMetric[float]

Post-processor for calculating Input Throughput metrics from records. This is only applicable to streaming responses.

Formula

Input Throughput = Input Sequence Length / Time to First Token (seconds)

Source code in aiperf/metrics/types/input_throughput.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class InputThroughputMetric(BaseRecordMetric[float]):
    """
    Record metric for input token throughput. Only applicable to streaming responses.

    Formula:
        Input Throughput = Input Sequence Length / Time to First Token (seconds)
    """

    tag = "input_throughput"
    header = "Input Throughput"
    unit = MetricOverTimeUnit.TOKENS_PER_SECOND
    flags = (
        MetricFlags.STREAMING_TOKENS_ONLY
        | MetricFlags.LARGER_IS_BETTER
        | MetricFlags.HIDDEN  # Hidden for now, as it is new and not yet validated
    )
    required_metrics = {
        InputSequenceLengthMetric.tag,
        TTFTMetric.tag,
    }

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> float:
        """Divide the input sequence length by the TTFT (converted to this metric's time base).

        Raises:
            ValueError: If the TTFT is missing or zero.
        """
        isl = record_metrics[InputSequenceLengthMetric.tag]
        ttft = record_metrics[TTFTMetric.tag]
        if ttft is None or ttft == 0:
            raise ValueError("Time to first token is not available for the record.")
        # Convert TTFT into this metric's time unit before dividing.
        ttft_in_unit = record_metrics.get_converted(TTFTMetric, self.unit.time_unit)  # type: ignore
        return isl / ttft_in_unit  # type: ignore

aiperf.metrics.types.inter_token_latency_metric

InterTokenLatencyMetric

Bases: BaseRecordMetric[float]

Post Processor for calculating Inter Token Latency (ITL) metric.

Formula

Inter Token Latency = (Request Latency - Time to First Token) / (Output Sequence Length - 1)

Source code in aiperf/metrics/types/inter_token_latency_metric.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class InterTokenLatencyMetric(BaseRecordMetric[float]):
    """
    Post Processor for calculating Inter Token Latency (ITL) metric.

    Formula:
        Inter Token Latency = (Request Latency - Time to First Token) / (Output Sequence Length - 1)
    """

    tag = "inter_token_latency"
    header = "Inter Token Latency"
    unit = MetricTimeUnit.NANOSECONDS
    display_unit = MetricTimeUnit.MILLISECONDS
    display_order = 400
    # FIX: removed MetricFlags.LARGER_IS_BETTER — ITL is a latency, so smaller
    # values are better; this matches TTFTMetric and RequestLatencyMetric,
    # which carry no LARGER_IS_BETTER flag.
    flags = MetricFlags.STREAMING_TOKENS_ONLY
    required_metrics = {
        RequestLatencyMetric.tag,
        TTFTMetric.tag,
        OutputSequenceLengthMetric.tag,
    }

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> float:
        """
        Calculates the Inter Token Latency (ITL) metric.

        Raises:
            ValueError: If the output sequence length is missing or below 2
                (at least two tokens are needed to have an inter-token gap).
        """
        osl = record_metrics[OutputSequenceLengthMetric.tag]
        if osl is None or osl < 2:
            raise ValueError(f"Output sequence length must be at least 2, got {osl}")

        ttft = record_metrics[TTFTMetric.tag]
        request_latency = record_metrics[RequestLatencyMetric.tag]

        # Time spent after the first token, averaged over the remaining osl - 1 tokens.
        return (request_latency - ttft) / (osl - 1)  # type: ignore

aiperf.metrics.types.max_response_metric

MaxResponseTimestampMetric

Bases: BaseAggregateMetric[int]

Post-processor for calculating the maximum response timestamp metric from records.

Formula

Maximum Response Timestamp = Max(Final Response Timestamps)

Source code in aiperf/metrics/types/max_response_metric.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class MaxResponseTimestampMetric(BaseAggregateMetric[int]):
    """
    Aggregate metric tracking the latest (maximum) final-response timestamp.

    Formula:
        Maximum Response Timestamp = Max(Final Response Timestamps)
    """

    tag = "max_response_timestamp"
    header = "Maximum Response Timestamp"
    unit = MetricTimeUnit.NANOSECONDS
    display_unit = MetricDateTimeUnit.DATE_TIME
    flags = MetricFlags.HIDDEN
    required_metrics = {
        RequestLatencyMetric.tag,
    }

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> int:
        """Return the wall-clock timestamp of the record's final response."""
        # The request's timestamp_ns is the only wall-clock time available, so
        # the final response timestamp is reconstructed by adding the request
        # latency to it.
        request_latency: int = record_metrics[RequestLatencyMetric.tag]  # type: ignore
        return record.timestamp_ns + request_latency

    def _aggregate_value(self, value: int) -> None:
        """Keep the maximum of the values reported by the different processes."""
        self._value = max(self._value, value)

aiperf.metrics.types.min_request_metric

MinRequestTimestampMetric

Bases: BaseAggregateMetric[int]

Post-processor for calculating the minimum response timestamp metric from records.

Formula

Minimum Request Timestamp = Min(Request Timestamps)

Source code in aiperf/metrics/types/min_request_metric.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class MinRequestTimestampMetric(BaseAggregateMetric[int]):
    """
    Aggregate metric tracking the earliest (minimum) request timestamp.

    Formula:
        Minimum Request Timestamp = Min(Request Timestamps)
    """

    tag = "min_request_timestamp"
    header = "Minimum Request Timestamp"
    unit = MetricTimeUnit.NANOSECONDS
    display_unit = MetricDateTimeUnit.DATE_TIME
    flags = MetricFlags.HIDDEN
    required_metrics = None

    def __init__(self) -> None:
        # Start at the largest possible int so any real timestamp replaces it.
        super().__init__(default_value=sys.maxsize)

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> int:
        """Return the record's wall-clock request timestamp."""
        # timestamp_ns (wall-clock) is used intentionally, not start_perf_ns.
        return record.timestamp_ns

    def _aggregate_value(self, value: int) -> None:
        """Keep the minimum of the values reported by the different processes."""
        self._value = min(self._value, value)

aiperf.metrics.types.output_sequence_length_metric

OutputSequenceLengthMetric

Bases: BaseRecordMetric[int]

Post-processor for calculating Output Sequence Length (OSL) metrics from records.

Formula

Output Sequence Length = Sum(Output Token Counts)

Source code in aiperf/metrics/types/output_sequence_length_metric.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class OutputSequenceLengthMetric(BaseRecordMetric[int]):
    """
    Record metric reporting the Output Sequence Length (OSL) of each record.

    Formula:
        Output Sequence Length = Sum(Output Token Counts)
    """

    tag = "output_sequence_length"
    header = "Output Sequence Length"
    unit = GenericMetricUnit.TOKENS
    display_order = 600
    flags = MetricFlags.PRODUCES_TOKENS_ONLY | MetricFlags.LARGER_IS_BETTER
    required_metrics = None

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> int:
        """Return the record's output token count.

        Raises:
            ValueError: If the record does not have an output token count.
        """
        token_count = record.output_token_count
        if token_count is None:
            raise ValueError("Output token count is missing in the record.")
        return token_count

aiperf.metrics.types.output_token_throughput_metric

OutputTokenThroughputMetric

Bases: BaseDerivedMetric[float]

Post Processor for calculating Output Token Throughput Metric.

Formula

Output Token Throughput = Benchmark Token Count / Benchmark Duration (seconds)

Source code in aiperf/metrics/types/output_token_throughput_metric.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class OutputTokenThroughputMetric(BaseDerivedMetric[float]):
    """
    Derived metric for the benchmark-wide output token throughput.

    Formula:
        Output Token Throughput = Benchmark Token Count / Benchmark Duration (seconds)
    """

    tag = "output_token_throughput"
    header = "Output Token Throughput"
    unit = MetricOverTimeUnit.TOKENS_PER_SECOND
    display_order = 800
    flags = MetricFlags.PRODUCES_TOKENS_ONLY | MetricFlags.LARGER_IS_BETTER
    required_metrics = {
        BenchmarkTokenCountMetric.tag,
        BenchmarkDurationMetric.tag,
    }

    def _derive_value(
        self,
        metric_results: MetricResultsDict,
    ) -> float:
        """Divide the total token count by the benchmark duration in this metric's time base.

        Raises:
            ValueError: If the benchmark duration is missing or zero.
        """
        token_total = metric_results[BenchmarkTokenCountMetric.tag]
        duration = metric_results[BenchmarkDurationMetric.tag]
        if duration is None or duration == 0:
            raise ValueError("Benchmark duration is not available.")

        # Convert the duration into this metric's time unit before dividing.
        duration_in_unit = metric_results.get_converted(  # type: ignore
            BenchmarkDurationMetric,
            self.unit.time_unit,  # type: ignore
        )
        return token_total / duration_in_unit  # type: ignore

aiperf.metrics.types.output_token_throughput_per_user_metric

OutputTokenThroughputPerUserMetric

Bases: BaseRecordMetric[float]

Post Processor for calculating Output Token Throughput Per User Metric.

Formula

Output Token Throughput Per User = 1 / Inter-Token Latency (seconds)

Source code in aiperf/metrics/types/output_token_throughput_per_user_metric.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class OutputTokenThroughputPerUserMetric(BaseRecordMetric[float]):
    """
    Post Processor for calculating Output Token Throughput Per User Metric.

    Formula:
        Output Token Throughput Per User = 1 / Inter-Token Latency (seconds)
    """

    tag = "output_token_throughput_per_user"
    # FIX: removed the stray trailing "\n" from the header so it renders on a
    # single line, consistent with every other metric header in this package.
    header = "Output Token Throughput Per User"
    unit = MetricOverTimeUnit.TOKENS_PER_SECOND_PER_USER
    display_order = 500
    flags = MetricFlags.STREAMING_TOKENS_ONLY | MetricFlags.LARGER_IS_BETTER
    required_metrics = {
        InterTokenLatencyMetric.tag,
    }

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> float:
        """Compute per-user throughput as the inverse of the inter-token latency.

        Raises:
            ValueError: If the inter-token latency is missing or zero.
        """
        itl = record_metrics[InterTokenLatencyMetric.tag]
        if itl is None or itl == 0:
            raise ValueError(
                "Inter-token latency is 0, cannot compute output token throughput per user."
            )
        # Convert ITL into this metric's time base before inverting.
        converted_itl = record_metrics.get_converted(
            InterTokenLatencyMetric,
            self.unit.time_unit,  # type: ignore
        )
        return 1 / converted_itl

aiperf.metrics.types.request_count_metric

RequestCountMetric

Bases: BaseAggregateMetric[int]

Post-processor for counting the number of valid requests.

Formula

Request Count = Sum(Valid Requests)

Source code in aiperf/metrics/types/request_count_metric.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class RequestCountMetric(BaseAggregateMetric[int]):
    """
    Aggregate metric counting the number of valid requests.

    Formula:
        Request Count = Sum(Valid Requests)
    """

    tag = "request_count"
    header = "Request Count"
    unit = GenericMetricUnit.REQUESTS
    display_order = 1000
    flags = MetricFlags.LARGER_IS_BETTER
    required_metrics = None

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> int:
        # Each valid record counts as exactly one request; the running total
        # is maintained by the ResultsProcessor via `_aggregate_value`.
        return 1

    def _aggregate_value(self, value: int) -> None:
        """Sum the per-process request counts into the running total."""
        self._value += value

aiperf.metrics.types.request_latency_metric

RequestLatencyMetric

Bases: BaseRecordMetric[int]

Post-processor for calculating Request Latency metrics from records.

Formula

Request Latency = Final Response Timestamp - Request Start Timestamp

Source code in aiperf/metrics/types/request_latency_metric.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class RequestLatencyMetric(BaseRecordMetric[int]):
    """
    Record metric measuring end-to-end request latency.

    Formula:
        Request Latency = Final Response Timestamp - Request Start Timestamp
    """

    tag = "request_latency"
    header = "Request Latency"
    unit = MetricTimeUnit.NANOSECONDS
    display_unit = MetricTimeUnit.MILLISECONDS
    display_order = 300
    flags = MetricFlags.NONE
    required_metrics = None

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> int:
        """Return the elapsed time between request start and the last response."""
        started_at: int = record.start_perf_ns
        finished_at: int = record.responses[-1].perf_ns
        return finished_at - started_at

aiperf.metrics.types.request_throughput_metric

RequestThroughputMetric

Bases: BaseDerivedMetric[float]

Post Processor for calculating Request throughput metrics from records.

Formula

Request Throughput = Valid Request Count / Benchmark Duration (seconds)

Source code in aiperf/metrics/types/request_throughput_metric.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class RequestThroughputMetric(BaseDerivedMetric[float]):
    """
    Derived metric for the benchmark-wide request throughput.

    Formula:
        Request Throughput = Valid Request Count / Benchmark Duration (seconds)
    """

    tag = "request_throughput"
    header = "Request Throughput"
    unit = MetricOverTimeUnit.REQUESTS_PER_SECOND
    display_order = 900
    flags = MetricFlags.LARGER_IS_BETTER
    required_metrics = {
        RequestCountMetric.tag,
        BenchmarkDurationMetric.tag,
    }

    def _derive_value(
        self,
        metric_results: MetricResultsDict,
    ) -> float:
        """Divide the valid request count by the benchmark duration in this metric's time base.

        Raises:
            ValueError: If the benchmark duration is missing/zero or the request count is missing.
        """
        duration = metric_results[BenchmarkDurationMetric.tag]
        if duration is None or duration == 0:
            raise ValueError(
                "Benchmark duration is required and must be greater than 0 to calculate request throughput."
            )

        count = metric_results[RequestCountMetric.tag]
        if count is None:
            raise ValueError(
                "Request count is required to calculate request throughput."
            )

        # Convert the duration into this metric's time unit before dividing.
        duration_in_unit = metric_results.get_converted(  # type: ignore
            BenchmarkDurationMetric,
            self.unit.time_unit,  # type: ignore
        )
        return count / duration_in_unit  # type: ignore

aiperf.metrics.types.ttft_metric

TTFTMetric

Bases: BaseRecordMetric[int]

Post-processor for calculating Time to First Token (TTFT) metrics from records.

Formula

TTFT = First Response Timestamp - Request Start Timestamp

Source code in aiperf/metrics/types/ttft_metric.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class TTFTMetric(BaseRecordMetric[int]):
    """
    Post-processor for calculating Time to First Token (TTFT) metrics from records.

    Formula:
        TTFT = First Response Timestamp - Request Start Timestamp
    """

    tag = "ttft"
    header = "Time to First Token"
    # Raw values are nanoseconds; milliseconds are used only for display.
    unit = MetricTimeUnit.NANOSECONDS
    display_unit = MetricTimeUnit.MILLISECONDS
    display_order = 100
    flags = MetricFlags.STREAMING_TOKENS_ONLY
    required_metrics = None

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> int:
        """
        This method extracts the timestamps from the request start and the first response in the given
        RequestRecord object, computes the difference (TTFT), and returns the result.

        Raises:
            ValueError: If the record does not have at least one response.
        """

        if len(record.responses) < 1:
            raise ValueError(
                "Record must have at least one response to calculate TTFT."
            )

        request_ts: int = record.request.start_perf_ns
        first_response_ts: int = record.responses[0].perf_ns
        # Guard against clock/ordering anomalies that would yield a negative TTFT.
        if first_response_ts < request_ts:
            raise ValueError(
                "First response timestamp is before request start timestamp, cannot compute TTFT."
            )

        return first_response_ts - request_ts

aiperf.metrics.types.ttst_metric

TTSTMetric

Bases: BaseRecordMetric[int]

Post-processor for calculating Time to Second Token (TTST) metrics from records.

Formula

TTST = Second Response Timestamp - First Response Timestamp

Source code in aiperf/metrics/types/ttst_metric.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class TTSTMetric(BaseRecordMetric[int]):
    """
    Post-processor for calculating Time to Second Token (TTST) metrics from records.

    Formula:
        TTST = Second Response Timestamp - First Response Timestamp
    """

    tag = "ttst"
    header = "Time to Second Token"
    # Raw values are nanoseconds; milliseconds are used only for display.
    unit = MetricTimeUnit.NANOSECONDS
    display_unit = MetricTimeUnit.MILLISECONDS
    display_order = 200
    flags = MetricFlags.STREAMING_TOKENS_ONLY
    required_metrics = None

    def _parse_record(
        self,
        record: ParsedResponseRecord,
        record_metrics: MetricRecordDict,
    ) -> int:
        """
        This method extracts the timestamps from the first and second response in the given
        RequestRecord object, computes the difference (TTST), and returns the result.

        Raises:
            ValueError: If the record does not have at least two responses, or if the second response is before the first response.
        """

        if len(record.responses) < 2:
            raise ValueError(
                "Record must have at least two responses to calculate TTST."
            )
        # Guard against clock/ordering anomalies that would yield a negative TTST.
        if record.responses[1].perf_ns < record.responses[0].perf_ns:
            raise ValueError(
                "Second response timestamp must be greater than or equal to the first response timestamp."
            )

        first_response_ts: int = record.responses[0].perf_ns
        second_response_ts: int = record.responses[1].perf_ns
        return second_response_ts - first_response_ts

aiperf.module_loader

Module loader for AIPerf.

This module is used to load all modules into the system to ensure everything is registered and ready to be used. This is done to avoid the performance penalty of importing all modules during CLI startup, while still ensuring that all implementations are properly registered with their factories.

ensure_modules_loaded()

Ensure all modules are loaded exactly once.

Source code in aiperf/module_loader.py
47
48
49
50
51
52
53
54
55
56
57
58
def ensure_modules_loaded() -> None:
    """Ensure all modules are loaded exactly once.

    Thread-safe: the lock serializes concurrent callers so _load_all_modules()
    runs at most once, with the module-level flag recording completion.
    """
    global _modules_loaded
    with _modules_loaded_lock:
        if not _modules_loaded:
            start_time = time.perf_counter()
            _logger.debug("Loading all modules")
            _load_all_modules()
            _logger.debug(
                f"Modules loaded in {time.perf_counter() - start_time:.2f} seconds"
            )
            # Only set after a successful load so a failed attempt can be retried.
            _modules_loaded = True

aiperf.parsers.inference_result_parser

InferenceResultParser

Bases: CommunicationMixin

InferenceResultParser is responsible for parsing the inference results.

Source code in aiperf/parsers/inference_result_parser.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
class InferenceResultParser(CommunicationMixin):
    """InferenceResultParser is responsible for parsing the inference results."""

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
    ) -> None:
        """Set up the dataset request client, tokenizer cache, and response extractor.

        Args:
            service_config: Service-level configuration forwarded to the communication mixin.
            user_config: User configuration; also used to derive the model endpoint info.
        """
        super().__init__(
            service_config=service_config,
            user_config=user_config,
        )
        # Client used to fetch conversation turns from the dataset manager proxy
        # (see compute_input_token_count).
        self.conversation_request_client: RequestClientProtocol = (
            self.comms.create_request_client(
                CommAddress.DATASET_MANAGER_PROXY_FRONTEND,
            )
        )
        # Lazily-populated cache of tokenizers keyed by model name.
        self.tokenizers: dict[str, Tokenizer] = {}
        self.user_config: UserConfig = user_config
        # Guards concurrent access to the tokenizer cache.
        self.tokenizer_lock: asyncio.Lock = asyncio.Lock()
        self.model_endpoint: ModelEndpointInfo = ModelEndpointInfo.from_user_config(
            user_config
        )
        # Extractor matching the endpoint type; re-created again in _initialize.
        self.extractor: ResponseExtractorProtocol = (
            ResponseExtractorFactory.create_instance(
                self.model_endpoint.endpoint.type,
                model_endpoint=self.model_endpoint,
            )
        )

    @on_init
    async def _initialize(self) -> None:
        """Initialize inference result parser-specific components."""
        self.debug("Initializing inference result parser")

        self.extractor = ResponseExtractorFactory.create_instance(
            self.model_endpoint.endpoint.type,
            model_endpoint=self.model_endpoint,
        )

    async def configure(self) -> None:
        """Configure the tokenizers."""
        self.info("Configuring tokenizers for inference result parser")
        begin = time.perf_counter()
        # Pre-load a tokenizer per configured model; an explicitly configured
        # tokenizer name takes precedence over the model name.
        async with self.tokenizer_lock:
            self.tokenizers = {
                model.name: Tokenizer.from_pretrained(
                    self.user_config.tokenizer.name or model.name,
                    trust_remote_code=self.user_config.tokenizer.trust_remote_code,
                    revision=self.user_config.tokenizer.revision,
                )
                for model in self.model_endpoint.models.models
            }
        duration = time.perf_counter() - begin
        # Summarize what was loaded for the log line below.
        tokenizer_info = {
            model: {
                "class": tokenizer._tokenizer.__class__.__name__,
                "name_or_path": getattr(tokenizer._tokenizer, "name_or_path", ""),
            }
            for model, tokenizer in self.tokenizers.items()
        }
        self.info(f"Initialized tokenizers: {tokenizer_info} in {duration:.2f} seconds")

    async def get_tokenizer(self, model: str) -> Tokenizer:
        """Get the tokenizer for a given model."""
        # Lock covers both the membership check and the insertion so two
        # concurrent callers don't load the same tokenizer twice.
        async with self.tokenizer_lock:
            if model not in self.tokenizers:
                self.tokenizers[model] = Tokenizer.from_pretrained(
                    self.user_config.tokenizer.name or model,
                    trust_remote_code=self.user_config.tokenizer.trust_remote_code,
                    revision=self.user_config.tokenizer.revision,
                )
            return self.tokenizers[model]

    async def parse_request_record(
        self, request_record: RequestRecord
    ) -> ParsedResponseRecord:
        """Handle an inference results message.

        Error and invalid records produce a ParsedResponseRecord with no
        responses; only valid records are fully processed.
        """
        self.trace_or_debug(
            lambda: f"Received inference results message: {request_record}",
            lambda: "Received inference results",
        )

        if request_record.has_error:
            return ParsedResponseRecord(
                request=request_record,
                responses=[],
            )

        elif request_record.valid:
            try:
                record = await self.process_valid_record(request_record)
                self.debug(
                    lambda: f"Received {len(record.request.responses)} responses, input_token_count: {record.input_token_count}, output_token_count: {record.output_token_count}"
                )
                return record
            except Exception as e:
                # TODO: We should add an ErrorDetails to the response record and not the request record.
                self.exception(f"Error processing valid record: {e}")
                request_record.error = ErrorDetails.from_exception(e)
                return ParsedResponseRecord(
                    request=request_record,
                    responses=[],
                )
        else:
            self.warning(f"Received invalid inference results: {request_record}")
            # TODO: We should add an ErrorDetails to response record and not the request record.
            request_record.error = ErrorDetails(
                code=None,
                message="Invalid inference results",
                type="InvalidInferenceResults",
            )
            return ParsedResponseRecord(
                request=request_record,
                responses=[],
            )

    async def process_valid_record(
        self, request_record: RequestRecord
    ) -> ParsedResponseRecord:
        """Process a valid request record.

        Extracts per-response data and computes input/output token counts.
        Records without a model name are returned with None token counts.
        """
        if request_record.model_name is None:
            self.warning(
                lambda: f"Model name is None, unable to process record: {request_record}"
            )
            return ParsedResponseRecord(
                request=request_record,
                responses=[],
                input_token_count=None,
                output_token_count=None,
            )

        tokenizer = await self.get_tokenizer(request_record.model_name)
        resp = await self.extractor.extract_response_data(request_record, tokenizer)
        input_token_count = await self.compute_input_token_count(
            request_record, tokenizer
        )
        # Sum only responses that actually have a token count.
        output_token_count = sum(
            response.token_count
            for response in resp
            if response.token_count is not None
        )

        return ParsedResponseRecord(
            request=request_record,
            responses=resp,
            input_token_count=input_token_count,
            output_token_count=output_token_count,
        )

    async def compute_input_token_count(
        self, request_record: RequestRecord, tokenizer: Tokenizer
    ) -> int | None:
        """Compute the number of tokens in the input for a given request record.

        Fetches the original conversation turn from the dataset manager and
        tokenizes its text contents. Returns None if the turn cannot be
        identified or retrieved.
        """
        if request_record.conversation_id is None or request_record.turn_index is None:
            self.warning(
                lambda: f"Conversation ID or turn index is None: {request_record.conversation_id=} {request_record.turn_index=}"
            )
            return None

        turn_response: ConversationTurnResponseMessage = (
            await self.conversation_request_client.request(
                ConversationTurnRequestMessage(
                    service_id=self.id,
                    conversation_id=request_record.conversation_id,
                    turn_index=request_record.turn_index,
                )
            )
        )
        if isinstance(turn_response, ErrorMessage):
            self.error(lambda: f"Error getting turn response: {turn_response}")
            return None

        turn = turn_response.turn
        # Total token count across all text contents of the turn.
        return sum(
            len(tokenizer.encode(content))
            for text in turn.texts
            for content in text.contents
        )

compute_input_token_count(request_record, tokenizer) async

Compute the number of tokens in the input for a given request record.

Source code in aiperf/parsers/inference_result_parser.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
async def compute_input_token_count(
    self, request_record: RequestRecord, tokenizer: Tokenizer
) -> int | None:
    """Compute the number of tokens in the input for a given request record.

    Fetches the original conversation turn from the dataset manager and
    tokenizes its text contents; returns None if the turn cannot be
    identified or retrieved.
    """
    if request_record.conversation_id is None or request_record.turn_index is None:
        self.warning(
            lambda: f"Conversation ID or turn index is None: {request_record.conversation_id=} {request_record.turn_index=}"
        )
        return None

    turn_response: ConversationTurnResponseMessage = (
        await self.conversation_request_client.request(
            ConversationTurnRequestMessage(
                service_id=self.id,
                conversation_id=request_record.conversation_id,
                turn_index=request_record.turn_index,
            )
        )
    )
    if isinstance(turn_response, ErrorMessage):
        self.error(lambda: f"Error getting turn response: {turn_response}")
        return None

    turn = turn_response.turn
    # Total token count across all text contents of the turn.
    return sum(
        len(tokenizer.encode(content))
        for text in turn.texts
        for content in text.contents
    )

configure() async

Configure the tokenizers.

Source code in aiperf/parsers/inference_result_parser.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
async def configure(self) -> None:
    """Configure the tokenizers."""
    self.info("Configuring tokenizers for inference result parser")
    begin = time.perf_counter()
    # Pre-load a tokenizer per configured model; an explicitly configured
    # tokenizer name takes precedence over the model name.
    async with self.tokenizer_lock:
        self.tokenizers = {
            model.name: Tokenizer.from_pretrained(
                self.user_config.tokenizer.name or model.name,
                trust_remote_code=self.user_config.tokenizer.trust_remote_code,
                revision=self.user_config.tokenizer.revision,
            )
            for model in self.model_endpoint.models.models
        }
    duration = time.perf_counter() - begin
    # Summarize what was loaded for the log line below.
    tokenizer_info = {
        model: {
            "class": tokenizer._tokenizer.__class__.__name__,
            "name_or_path": getattr(tokenizer._tokenizer, "name_or_path", ""),
        }
        for model, tokenizer in self.tokenizers.items()
    }
    self.info(f"Initialized tokenizers: {tokenizer_info} in {duration:.2f} seconds")

get_tokenizer(model) async

Get the tokenizer for a given model.

Source code in aiperf/parsers/inference_result_parser.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
async def get_tokenizer(self, model: str) -> Tokenizer:
    """Get the tokenizer for a given model."""
    # Lock covers both the membership check and the insertion so two
    # concurrent callers don't load the same tokenizer twice.
    async with self.tokenizer_lock:
        if model not in self.tokenizers:
            self.tokenizers[model] = Tokenizer.from_pretrained(
                self.user_config.tokenizer.name or model,
                trust_remote_code=self.user_config.tokenizer.trust_remote_code,
                revision=self.user_config.tokenizer.revision,
            )
        return self.tokenizers[model]

parse_request_record(request_record) async

Handle an inference results message.

Source code in aiperf/parsers/inference_result_parser.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
async def parse_request_record(
    self, request_record: RequestRecord
) -> ParsedResponseRecord:
    """Handle an inference results message.

    Error and invalid records produce a ParsedResponseRecord with no
    responses; only valid records are fully processed.
    """
    self.trace_or_debug(
        lambda: f"Received inference results message: {request_record}",
        lambda: "Received inference results",
    )

    if request_record.has_error:
        return ParsedResponseRecord(
            request=request_record,
            responses=[],
        )

    elif request_record.valid:
        try:
            record = await self.process_valid_record(request_record)
            self.debug(
                lambda: f"Received {len(record.request.responses)} responses, input_token_count: {record.input_token_count}, output_token_count: {record.output_token_count}"
            )
            return record
        except Exception as e:
            # TODO: We should add an ErrorDetails to the response record and not the request record.
            self.exception(f"Error processing valid record: {e}")
            request_record.error = ErrorDetails.from_exception(e)
            return ParsedResponseRecord(
                request=request_record,
                responses=[],
            )
    else:
        self.warning(f"Received invalid inference results: {request_record}")
        # TODO: We should add an ErrorDetails to response record and not the request record.
        request_record.error = ErrorDetails(
            code=None,
            message="Invalid inference results",
            type="InvalidInferenceResults",
        )
        return ParsedResponseRecord(
            request=request_record,
            responses=[],
        )

process_valid_record(request_record) async

Process a valid request record.

Source code in aiperf/parsers/inference_result_parser.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
async def process_valid_record(
    self, request_record: RequestRecord
) -> ParsedResponseRecord:
    """Process a valid request record.

    Extracts per-response data and computes input/output token counts.
    Records without a model name are returned with None token counts.
    """
    if request_record.model_name is None:
        self.warning(
            lambda: f"Model name is None, unable to process record: {request_record}"
        )
        return ParsedResponseRecord(
            request=request_record,
            responses=[],
            input_token_count=None,
            output_token_count=None,
        )

    tokenizer = await self.get_tokenizer(request_record.model_name)
    resp = await self.extractor.extract_response_data(request_record, tokenizer)
    input_token_count = await self.compute_input_token_count(
        request_record, tokenizer
    )
    # Sum only responses that actually have a token count.
    output_token_count = sum(
        response.token_count
        for response in resp
        if response.token_count is not None
    )

    return ParsedResponseRecord(
        request=request_record,
        responses=resp,
        input_token_count=input_token_count,
        output_token_count=output_token_count,
    )

aiperf.parsers.openai_parsers

OpenAIObject

Bases: CaseInsensitiveStrEnum

Types of OpenAI objects.

Source code in aiperf/parsers/openai_parsers.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class OpenAIObject(CaseInsensitiveStrEnum):
    """Types of OpenAI objects."""

    CHAT_COMPLETION = "chat.completion"
    CHAT_COMPLETION_CHUNK = "chat.completion.chunk"
    COMPLETION = "completion"
    EMBEDDING = "embedding"
    LIST = "list"
    RESPONSE = "response"
    TEXT_COMPLETION = "text_completion"

    @classmethod
    def parse(cls, text: str) -> BaseModel:
        """Attempt to parse a string into an OpenAI object.

        The JSON payload's "object" field selects which Pydantic model to
        validate against; "list" payloads are delegated to parse_list.

        Raises:
            ValueError: If the object is invalid.
        """
        try:
            obj = load_json_str(text)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid OpenAI object: {text}") from e

        # Mapping of OpenAI object types to their corresponding Pydantic models.
        _object_mapping: dict[str, type[BaseModel]] = {
            cls.CHAT_COMPLETION: ChatCompletion,
            cls.CHAT_COMPLETION_CHUNK: ChatCompletionChunk,
            cls.COMPLETION: Completion,
            cls.RESPONSE: ResponsesModel,
            cls.TEXT_COMPLETION: Completion,  # Alias for vLLM compatibility
        }

        obj_type = obj.get("object")
        if obj_type is None:
            raise ValueError(f"Invalid OpenAI object: {obj}")
        if obj_type == cls.LIST:
            return cls.parse_list(obj)
        if obj_type not in _object_mapping:
            raise ValueError(f"Invalid OpenAI object type: {obj_type}")
        try:
            # Hotfix: vLLM does not always include a finish_reason, which Pydantic requires.
            # Without this code, model_validate will raise an objection due to the missing finish_reason.
            if obj_type == cls.TEXT_COMPLETION:
                for choice in obj.get("choices", []):
                    if choice.get("finish_reason") is None:
                        choice["finish_reason"] = "stop"
            return _object_mapping[obj_type].model_validate(obj)
        except Exception as e:
            raise ValueError(f"Invalid OpenAI object: {text}") from e

    @classmethod
    def parse_list(cls, obj: Any) -> BaseModel:
        """Attempt to parse a string into an OpenAI object from a list.

        Only lists whose items are all embedding objects are supported.

        Raises:
            ValueError: If the object is invalid.
        """
        data = obj.get("data", [])
        if all(
            isinstance(item, dict) and item.get("object") == cls.EMBEDDING
            for item in data
        ):
            return CreateEmbeddingResponse.model_validate(obj)
        else:
            raise ValueError(f"Receive invalid list in response: {obj}")

parse(text) classmethod

Attempt to parse a string into an OpenAI object.

Raises:

Type Description
ValueError

If the object is invalid.

Source code in aiperf/parsers/openai_parsers.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
@classmethod
def parse(cls, text: str) -> BaseModel:
    """Attempt to parse a string into an OpenAI object.

    The JSON payload's "object" field selects which Pydantic model to
    validate against; "list" payloads are delegated to parse_list.

    Raises:
        ValueError: If the object is invalid.
    """
    try:
        obj = load_json_str(text)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid OpenAI object: {text}") from e

    # Mapping of OpenAI object types to their corresponding Pydantic models.
    _object_mapping: dict[str, type[BaseModel]] = {
        cls.CHAT_COMPLETION: ChatCompletion,
        cls.CHAT_COMPLETION_CHUNK: ChatCompletionChunk,
        cls.COMPLETION: Completion,
        cls.RESPONSE: ResponsesModel,
        cls.TEXT_COMPLETION: Completion,  # Alias for vLLM compatibility
    }

    obj_type = obj.get("object")
    if obj_type is None:
        raise ValueError(f"Invalid OpenAI object: {obj}")
    if obj_type == cls.LIST:
        return cls.parse_list(obj)
    if obj_type not in _object_mapping:
        raise ValueError(f"Invalid OpenAI object type: {obj_type}")
    try:
        # Hotfix: vLLM does not always include a finish_reason, which Pydantic requires.
        # Without this code, model_validate will raise an objection due to the missing finish_reason.
        if obj_type == cls.TEXT_COMPLETION:
            for choice in obj.get("choices", []):
                if choice.get("finish_reason") is None:
                    choice["finish_reason"] = "stop"
        return _object_mapping[obj_type].model_validate(obj)
    except Exception as e:
        raise ValueError(f"Invalid OpenAI object: {text}") from e

parse_list(obj) classmethod

Attempt to parse a string into an OpenAI object from a list.

Raises:

Type Description
ValueError

If the object is invalid.

Source code in aiperf/parsers/openai_parsers.py
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
@classmethod
def parse_list(cls, obj: Any) -> BaseModel:
    """Attempt to parse a string into an OpenAI object from a list.

    Only lists whose items are all embedding objects are supported.

    Raises:
        ValueError: If the object is invalid.
    """
    data = obj.get("data", [])
    if all(
        isinstance(item, dict) and item.get("object") == cls.EMBEDDING
        for item in data
    ):
        return CreateEmbeddingResponse.model_validate(obj)
    else:
        raise ValueError(f"Receive invalid list in response: {obj}")

OpenAIResponseExtractor

Extractor for OpenAI responses.

Source code in aiperf/parsers/openai_parsers.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
@ResponseExtractorFactory.register_all(
    EndpointType.OPENAI_CHAT_COMPLETIONS,
    EndpointType.OPENAI_COMPLETIONS,
    EndpointType.OPENAI_EMBEDDINGS,
    EndpointType.OPENAI_RESPONSES,
)
class OpenAIResponseExtractor:
    """Extractor for OpenAI responses."""

    def __init__(self, model_endpoint: ModelEndpointInfo) -> None:
        """Create a new response extractor based on the provided configuration."""
        self.model_endpoint = model_endpoint

    def _parse_text_response(self, response: TextResponse) -> ResponseData | None:
        """Parse a TextResponse into a ResponseData object.

        Returns None when the payload contains no parseable content.
        """
        raw = response.text
        parsed = self._parse_text(raw)
        if parsed is None:
            return None

        return ResponseData(
            perf_ns=response.perf_ns,
            raw_text=[raw],
            parsed_text=[parsed],
            metadata={},
        )

    def _parse_sse_response(self, response: SSEMessage) -> ResponseData | None:
        """Parse a SSEMessage into a ResponseData object.

        Returns None when no data lines yield parseable content.
        """
        raw = response.extract_data_content()
        parsed = self._parse_sse(raw)
        if parsed is None or len(parsed) == 0:
            return None

        return ResponseData(
            perf_ns=response.perf_ns,
            raw_text=raw,
            parsed_text=parsed,
            metadata={},
        )

    def _parse_response(self, response: InferenceServerResponse) -> ResponseData | None:
        """Parse a response into a ResponseData object.

        Dispatches on the concrete response type; unknown types yield None.
        """
        if isinstance(response, TextResponse):
            return self._parse_text_response(response)
        elif isinstance(response, SSEMessage):
            return self._parse_sse_response(response)

    async def extract_response_data(
        self, record: RequestRecord, tokenizer: Tokenizer | None
    ) -> list[ResponseData]:
        """Extract the text from a server response message.

        When a tokenizer is supplied, each response's token_count is set to
        the total tokens across its parsed text fragments.
        """
        results = []
        for response in record.responses:
            response_data = self._parse_response(response)
            if response_data is None:
                continue

            if tokenizer is not None:
                response_data.token_count = sum(
                    len(tokenizer.encode(text))
                    for text in response_data.parsed_text
                    if text is not None
                )
            results.append(response_data)
        return results

    def _parse_text(self, raw_text: str) -> Any | None:
        """Parse the text of the response.

        Returns None for empty payloads and the SSE "[DONE]" sentinel.

        Raises:
            ValueError: If the payload is not a recognized OpenAI object.
        """
        if raw_text in ("", None, "[DONE]"):
            return None

        obj = OpenAIObject.parse(raw_text)

        # Dictionary mapping object types to their value extraction functions
        type_to_extractor = {
            # TODO: how to support multiple choices?
            ChatCompletion: lambda obj: obj.choices[0].message.content,
            # TODO: how to support multiple choices?
            ChatCompletionChunk: lambda obj: obj.choices[0].delta.content,
            # TODO: how to support multiple choices?
            Completion: lambda obj: obj.choices[0].text,
            CreateEmbeddingResponse: lambda obj: "",  # Don't store embedding data
            ResponsesModel: lambda obj: obj.output_text,
        }

        for obj_type, extractor in type_to_extractor.items():
            if isinstance(obj, obj_type):
                return extractor(obj)

        raise ValueError(f"Invalid OpenAI object: {raw_text}")

    def _parse_sse(self, raw_sse: list[str]) -> list[Any]:
        """Parse the SSE of the response.

        Unparseable/sentinel entries are dropped rather than kept as None.
        """
        result = []
        for sse in raw_sse:
            parsed = self._parse_text(sse)
            if parsed is None:
                continue
            result.append(parsed)
        return result

__init__(model_endpoint)

Create a new response extractor based on the provided configuration.

Source code in aiperf/parsers/openai_parsers.py
104
105
106
def __init__(self, model_endpoint: ModelEndpointInfo) -> None:
    """Create a new response extractor based on the provided configuration."""
    # Endpoint info is retained for use by the extraction methods.
    self.model_endpoint = model_endpoint

extract_response_data(record, tokenizer) async

Extract the text from a server response message.

Source code in aiperf/parsers/openai_parsers.py
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
async def extract_response_data(
    self, record: RequestRecord, tokenizer: Tokenizer | None
) -> list[ResponseData]:
    """Extract the text from a server response message.

    When a tokenizer is supplied, each response's token_count is set to
    the total tokens across its parsed text fragments.
    """
    results = []
    for response in record.responses:
        response_data = self._parse_response(response)
        if response_data is None:
            continue

        if tokenizer is not None:
            response_data.token_count = sum(
                len(tokenizer.encode(text))
                for text in response_data.parsed_text
                if text is not None
            )
        results.append(response_data)
    return results

aiperf.post_processors.base_metrics_processor

BaseMetricsProcessor

Bases: AIPerfLoggerMixin, ABC

Base class for all metrics processors. This class is responsible for filtering the metrics based on the user config.

Source code in aiperf/post_processors/base_metrics_processor.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class BaseMetricsProcessor(AIPerfLoggerMixin, ABC):
    """Base class for all metrics processors. This class is responsible for filtering the metrics based on the user config."""

    def __init__(self, user_config: UserConfig, **kwargs):
        self.user_config = user_config
        super().__init__(user_config=user_config, **kwargs)

    def get_filters(self) -> tuple[MetricFlags, MetricFlags]:
        """Get the filters for the metrics based on the user config.
        Returns:
            tuple[MetricFlags, MetricFlags]: The required and disallowed flags.
        """
        # Start unfiltered; each unsupported capability disallows one flag family.
        required_flags = MetricFlags.NONE
        disallowed_flags = MetricFlags.NONE
        endpoint = self.user_config.endpoint
        capability_flags = (
            (endpoint.type.produces_tokens, MetricFlags.PRODUCES_TOKENS_ONLY),
            (endpoint.type.supports_audio, MetricFlags.SUPPORTS_AUDIO_ONLY),
            (endpoint.type.supports_images, MetricFlags.SUPPORTS_IMAGE_ONLY),
            (endpoint.streaming, MetricFlags.STREAMING_ONLY),
        )
        for supported, flag in capability_flags:
            if not supported:
                disallowed_flags |= flag
        return required_flags, disallowed_flags

    def _setup_metrics(
        self,
        *metric_types: MetricType,
        error_metrics_only: bool = False,
        exclude_error_metrics: bool = False,
    ) -> list[BaseMetric]:
        """Return the metrics applicable to the endpoint type and user config,
        ordered by their dependencies to ensure proper computation order.

        Compute the metrics sequentially (not in parallel): a metric may
        depend on the results of metrics ordered before it.
        """
        required_flags, disallowed_flags = self.get_filters()
        if error_metrics_only:
            required_flags |= MetricFlags.ERROR_ONLY
        elif exclude_error_metrics:
            disallowed_flags |= MetricFlags.ERROR_ONLY

        supported_tags = MetricRegistry.tags_applicable_to(
            required_flags,
            disallowed_flags,
            *metric_types,
        )
        return [
            MetricRegistry.get_instance(tag)
            for tag in MetricRegistry.create_dependency_order_for(supported_tags)
        ]

get_filters()

Get the filters for the metrics based on the user config. Returns: tuple[MetricFlags, MetricFlags]: The required and disallowed flags.

Source code in aiperf/post_processors/base_metrics_processor.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def get_filters(self) -> tuple[MetricFlags, MetricFlags]:
    """Get the filters for the metrics based on the user config.
    Returns:
        tuple[MetricFlags, MetricFlags]: The required and disallowed flags.
    """
    # Start unfiltered; each unsupported capability disallows one flag family.
    required_flags = MetricFlags.NONE
    disallowed_flags = MetricFlags.NONE
    endpoint = self.user_config.endpoint
    capability_flags = (
        (endpoint.type.produces_tokens, MetricFlags.PRODUCES_TOKENS_ONLY),
        (endpoint.type.supports_audio, MetricFlags.SUPPORTS_AUDIO_ONLY),
        (endpoint.type.supports_images, MetricFlags.SUPPORTS_IMAGE_ONLY),
        (endpoint.streaming, MetricFlags.STREAMING_ONLY),
    )
    for supported, flag in capability_flags:
        if not supported:
            disallowed_flags |= flag
    return required_flags, disallowed_flags

aiperf.post_processors.metric_record_processor

MetricRecordProcessor

Bases: BaseMetricsProcessor

Processor for metric records.

This is the first stage of the metrics processing pipeline, and is done in a distributed manner across multiple service instances. It is responsible for streaming the records to the post processor, and computing the metrics from the records. It computes metrics from MetricType.RECORD and MetricType.AGGREGATE types.

Source code in aiperf/post_processors/metric_record_processor.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
@implements_protocol(RecordProcessorProtocol)
@RecordProcessorFactory.register(RecordProcessorType.METRIC_RECORD)
class MetricRecordProcessor(BaseMetricsProcessor):
    """Processor for metric records.

    This is the first stage of the metrics processing pipeline, and is done in a distributed manner across multiple service instances.
    It is responsible for streaming the records to the post processor, and computing the metrics from the records.
    It computes metrics from MetricType.RECORD and MetricType.AGGREGATE types."""

    def __init__(
        self,
        user_config: UserConfig,
        **kwargs,
    ) -> None:
        super().__init__(user_config=user_config, **kwargs)

        # Pre-resolve the (tag, parse_record) pairs once so the hot path avoids
        # repeated attribute lookups. Valid and errored records use separate
        # metric sets.
        self.valid_parse_funcs: list[
            tuple[MetricTagT, Callable[[ParsedResponseRecord, MetricRecordDict], Any]]
        ] = self._collect_parse_funcs(exclude_error_metrics=True)

        self.error_parse_funcs: list[
            tuple[MetricTagT, Callable[[ParsedResponseRecord, MetricRecordDict], Any]]
        ] = self._collect_parse_funcs(error_metrics_only=True)

    def _collect_parse_funcs(
        self, **filter_kwargs
    ) -> list[
        tuple[MetricTagT, Callable[[ParsedResponseRecord, MetricRecordDict], Any]]
    ]:
        """Build the (tag, parse_record) pairs for RECORD and AGGREGATE metrics."""
        return [
            (metric.tag, metric.parse_record)  # type: ignore
            for metric in self._setup_metrics(
                MetricType.RECORD, MetricType.AGGREGATE, **filter_kwargs
            )
        ]

    async def process_record(self, record: ParsedResponseRecord) -> MetricRecordDict:
        """Process a response record from the inference results parser."""
        record_metrics: MetricRecordDict = MetricRecordDict()
        if record.valid:
            parse_funcs = self.valid_parse_funcs
        else:
            parse_funcs = self.error_parse_funcs
        # NOTE: Metrics are computed sequentially in a loop, as a parse_record
        # call may depend on the results of metrics computed before it.
        for tag, parse_func in parse_funcs:
            try:
                record_metrics[tag] = parse_func(record, record_metrics)
            except Exception as e:
                self.warning(f"Error parsing record for metric '{tag}': {e}")
        return record_metrics

process_record(record) async

Process a response record from the inference results parser.

Source code in aiperf/post_processors/metric_record_processor.py
56
57
58
59
60
61
62
63
64
65
66
async def process_record(self, record: ParsedResponseRecord) -> MetricRecordDict:
    """Process a response record from the inference results parser."""
    record_metrics: MetricRecordDict = MetricRecordDict()
    if record.valid:
        parse_funcs = self.valid_parse_funcs
    else:
        parse_funcs = self.error_parse_funcs
    # NOTE: computed sequentially — a metric may read the results of earlier metrics.
    for metric_tag, compute in parse_funcs:
        try:
            record_metrics[metric_tag] = compute(record, record_metrics)
        except Exception as exc:
            self.warning(f"Error parsing record for metric '{metric_tag}': {exc}")
    return record_metrics

aiperf.post_processors.metric_results_processor

MetricResultsProcessor

Bases: BaseMetricsProcessor

Processor for metric results.

This is the final stage of the metrics processing pipeline, and is done in a unified manner by the RecordsManager. It is responsible for processing the results and returning them to the RecordsManager, as well as summarizing the results.

Source code in aiperf/post_processors/metric_results_processor.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
@implements_protocol(ResultsProcessorProtocol)
@ResultsProcessorFactory.register(ResultsProcessorType.METRIC_RESULTS)
class MetricResultsProcessor(BaseMetricsProcessor):
    """Processor for metric results.

    This is the final stage of the metrics processing pipeline, and is done in a unified manner by the RecordsManager.
    It is responsible for processing the results and returning them to the RecordsManager, as well as summarizing the results.
    """

    def __init__(self, user_config: UserConfig, **kwargs: Any):
        super().__init__(user_config=user_config, **kwargs)
        # For derived metrics, we don't care about splitting up the error metrics
        self.derive_funcs: dict[
            MetricTagT, Callable[[MetricResultsDict], MetricValueTypeT]
        ] = {
            metric.tag: metric.derive_value  # type: ignore
            for metric in self._setup_metrics(MetricType.DERIVED)
        }

        # Create the results dict, which will be used to store the results of non-derived metrics,
        # and then be updated with the derived metrics.
        self._results: MetricResultsDict = MetricResultsDict()

        # Get all of the metric classes.
        _all_metric_classes: list[type[BaseMetric]] = MetricRegistry.all_classes()

        # Pre-cache the types for the metrics.
        # Built from ALL registered metrics (not only the filtered set), so
        # process_result can type-dispatch any tag it receives.
        self._tags_to_types: dict[MetricTagT, MetricType] = {
            metric.tag: metric.type for metric in _all_metric_classes
        }

        # Pre-cache the instances for the metrics.
        self._instances_map: dict[MetricTagT, BaseMetric] = {
            tag: MetricRegistry.get_instance(tag) for tag in MetricRegistry.all_tags()
        }

        # Pre-cache the aggregate functions for the aggregate metrics.
        # NOTE(review): this map is not referenced anywhere in this class —
        # process_result calls aggregate_value via self._instances_map instead.
        # Confirm there are no external users before removing it.
        self._tags_to_aggregate_funcs: dict[
            MetricTagT, Callable[[MetricResultsDict], MetricValueTypeT]
        ] = {
            metric.tag: MetricRegistry.get_instance(metric.tag).aggregate_value  # type: ignore
            for metric in _all_metric_classes
            if metric.type == MetricType.AGGREGATE
        }

    async def process_result(self, incoming_metrics: MetricRecordDict) -> None:
        """Process a result from the metric record processor.

        RECORD metrics are appended to a per-tag deque of raw values;
        AGGREGATE metrics fold the value into the shared metric instance and
        store its current value. Any other metric type is reported as a warning.
        """
        if self.is_trace_enabled:
            self.trace(f"Processing incoming metrics: {incoming_metrics}")

        for tag, value in incoming_metrics.items():
            try:
                metric_type = self._tags_to_types[tag]
                if metric_type == MetricType.RECORD:
                    if tag not in self._results:
                        self._results[tag] = deque()
                    self._results[tag].append(value)  # type: ignore

                elif metric_type == MetricType.AGGREGATE:
                    metric: BaseAggregateMetric = self._instances_map[tag]  # type: ignore
                    metric.aggregate_value(value)
                    self._results[tag] = metric.current_value

                else:
                    # The ValueError is caught immediately below and surfaced
                    # as a warning, so one bad tag never aborts the batch.
                    raise ValueError(f"Metric '{tag}' is not a valid metric type")
            except Exception as e:
                self.warning(f"Error processing metric '{tag}': {e}")

        if self.is_trace_enabled:
            self.trace(f"Results after processing incoming metrics: {self._results}")

    async def summarize(self) -> list[MetricResult]:
        """Summarize the results.

        This will compute the values for the derived metrics, and then create the MetricResult objects for each metric.
        """
        # Compute the values for the derived metrics, and store them in the results dict.
        for tag, derive_func in self.derive_funcs.items():
            self._results[tag] = derive_func(self._results)

        # Compute and return the metric results.
        return [
            self._create_metric_result(tag, values)
            for tag, values in self._results.items()
        ]

    def _create_metric_result(
        self, tag: MetricTagT, values: MetricDictValueTypeT
    ) -> MetricResult:
        """Create a MetricResult from a the current values of a metric.

        Scalars (aggregate/derived values) are reported as a single-sample
        average; iterables (record-metric deques) get full distribution stats.
        """

        metric_class = self._instances_map[tag]

        if isinstance(values, int | float):
            return MetricResult(
                tag=metric_class.tag,
                header=metric_class.header,
                unit=str(metric_class.unit),
                avg=values,
                count=1,
            )

        if isinstance(values, Iterable):
            series = pd.Series(values, dtype=metric_class.value_type.dtype)
            # Series.quantile returns a Series indexed by the quantile value
            # itself, hence the float-label lookups below.
            quantiles = series.quantile(
                [0.01, 0.05, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99]
            )
            return MetricResult(
                tag=metric_class.tag,
                header=metric_class.header,
                unit=str(metric_class.unit),
                avg=series.mean(),
                min=series.min(),
                max=series.max(),
                p1=quantiles[0.01],
                p5=quantiles[0.05],
                p25=quantiles[0.25],
                p50=quantiles[0.50],
                p75=quantiles[0.75],
                p90=quantiles[0.90],
                p95=quantiles[0.95],
                p99=quantiles[0.99],
                std=series.std(),
                count=len(series),
            )

        raise ValueError(f"Unexpected values type: {type(values)}")

process_result(incoming_metrics) async

Process a result from the metric record processor.

Source code in aiperf/post_processors/metric_results_processor.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
async def process_result(self, incoming_metrics: MetricRecordDict) -> None:
    """Process a result from the metric record processor."""
    if self.is_trace_enabled:
        self.trace(f"Processing incoming metrics: {incoming_metrics}")

    for tag, metric_value in incoming_metrics.items():
        try:
            kind = self._tags_to_types[tag]
            if kind == MetricType.AGGREGATE:
                # Fold the value into the shared metric instance and expose
                # its running value in the results dict.
                aggregate: BaseAggregateMetric = self._instances_map[tag]  # type: ignore
                aggregate.aggregate_value(metric_value)
                self._results[tag] = aggregate.current_value
            elif kind == MetricType.RECORD:
                # Record metrics accumulate their raw values per tag.
                if tag not in self._results:
                    self._results[tag] = deque()
                self._results[tag].append(metric_value)  # type: ignore
            else:
                raise ValueError(f"Metric '{tag}' is not a valid metric type")
        except Exception as e:
            self.warning(f"Error processing metric '{tag}': {e}")

    if self.is_trace_enabled:
        self.trace(f"Results after processing incoming metrics: {self._results}")

summarize() async

Summarize the results.

This will compute the values for the derived metrics, and then create the MetricResult objects for each metric.

Source code in aiperf/post_processors/metric_results_processor.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
async def summarize(self) -> list[MetricResult]:
    """Summarize the results.

    Computes the derived metric values from the collected results, then
    builds a MetricResult object for every metric.
    """
    # Derived metrics are computed last, from the already-collected results.
    for derived_tag, derive in self.derive_funcs.items():
        self._results[derived_tag] = derive(self._results)

    summary: list[MetricResult] = []
    for tag, values in self._results.items():
        summary.append(self._create_metric_result(tag, values))
    return summary

aiperf.records.record_processor_service

RecordProcessor

Bases: PullClientMixin, BaseComponentService

RecordProcessor is responsible for processing the records and pushing them to the RecordsManager. This service is meant to be run in a distributed fashion, where the number of record processors can be scaled based on the load of the system.

Source code in aiperf/records/record_processor_service.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
@ServiceFactory.register(ServiceType.RECORD_PROCESSOR)
class RecordProcessor(PullClientMixin, BaseComponentService):
    """RecordProcessor is responsible for processing the records and pushing them to the RecordsManager.
    This service is meant to be run in a distributed fashion, where the number of record processors can be scaled
    based on the load of the system.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str | None = None,
    ) -> None:
        # Pull raw inference results from the proxy backend (connect, not bind,
        # since many processor instances share the same address).
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
            pull_client_address=CommAddress.RAW_INFERENCE_PROXY_BACKEND,
            pull_client_bind=False,
            pull_client_max_concurrency=DEFAULT_PULL_CLIENT_MAX_CONCURRENCY,
        )
        # Push computed metric records to the RecordsManager.
        self.records_push_client: PushClientProtocol = self.comms.create_push_client(
            CommAddress.RECORDS,
        )
        self.conversation_request_client: RequestClientProtocol = (
            self.comms.create_request_client(
                CommAddress.DATASET_MANAGER_PROXY_FRONTEND,
            )
        )
        # Per-model tokenizer cache, guarded by tokenizer_lock (see get_tokenizer).
        self.tokenizers: dict[str, Tokenizer] = {}
        self.user_config: UserConfig = user_config
        self.tokenizer_lock: asyncio.Lock = asyncio.Lock()
        self.model_endpoint: ModelEndpointInfo = ModelEndpointInfo.from_user_config(
            user_config
        )
        self.inference_result_parser = InferenceResultParser(
            service_config=service_config,
            user_config=user_config,
        )
        # Populated during _initialize with one instance per registered
        # RecordProcessor type.
        self.records_processors: list[RecordProcessorProtocol] = []

    @on_init
    async def _initialize(self) -> None:
        """Initialize record processor-specific components."""
        self.debug("Initializing record processor")

        # NOTE: self.extractor is first assigned here (not in __init__).
        self.extractor = ResponseExtractorFactory.create_instance(
            self.model_endpoint.endpoint.type,
            model_endpoint=self.model_endpoint,
        )

        # Initialize all the records streamers
        for processor_type in RecordProcessorFactory.get_all_class_types():
            self.records_processors.append(
                RecordProcessorFactory.create_instance(
                    processor_type,
                    service_config=self.service_config,
                    user_config=self.user_config,
                )
            )

    @on_command(CommandType.PROFILE_CONFIGURE)
    async def _profile_configure_command(
        self, message: ProfileConfigureCommand
    ) -> None:
        """Configure the tokenizers."""
        await self.inference_result_parser.configure()

    async def get_tokenizer(self, model: str) -> Tokenizer:
        """Get the tokenizer for a given model.

        Tokenizers are created lazily and cached per model name; the lock
        prevents concurrent callers from loading the same tokenizer twice.
        """
        async with self.tokenizer_lock:
            if model not in self.tokenizers:
                # A user-configured tokenizer name takes precedence over the model name.
                self.tokenizers[model] = Tokenizer.from_pretrained(
                    self.user_config.tokenizer.name or model,
                    trust_remote_code=self.user_config.tokenizer.trust_remote_code,
                    revision=self.user_config.tokenizer.revision,
                )
            return self.tokenizers[model]

    @on_pull_message(MessageType.INFERENCE_RESULTS)
    async def _on_inference_results(self, message: InferenceResultsMessage) -> None:
        """Handle an inference results message.

        Parses the raw record, runs it through every registered record
        processor, and pushes the successful metric results (exceptions are
        logged and dropped) to the RecordsManager.
        """
        parsed_record = await self.inference_result_parser.parse_request_record(
            message.record
        )
        raw_results = await self._process_record(parsed_record)
        results = []
        for result in raw_results:
            if isinstance(result, BaseException):
                self.warning(f"Error processing record: {result}")
            else:
                results.append(result)
        await self.records_push_client.push(
            MetricRecordsMessage(
                service_id=self.service_id,
                # The incoming message's service_id identifies the worker that
                # produced the record.
                worker_id=message.service_id,
                credit_phase=message.record.credit_phase,
                results=results,
                error=message.record.error,
            )
        )

    async def _process_record(
        self, record: ParsedResponseRecord
    ) -> list[MetricRecordDict | BaseException]:
        """Stream a record to the records processors."""
        # return_exceptions=True so one failing processor cannot cancel the rest.
        tasks = [
            processor.process_record(record) for processor in self.records_processors
        ]
        results: list[MetricRecordDict | BaseException] = await asyncio.gather(
            *tasks, return_exceptions=True
        )
        return results

get_tokenizer(model) async

Get the tokenizer for a given model.

Source code in aiperf/records/record_processor_service.py
105
106
107
108
109
110
111
112
113
114
async def get_tokenizer(self, model: str) -> Tokenizer:
    """Get the tokenizer for a given model (created lazily, cached per model)."""
    async with self.tokenizer_lock:
        if model not in self.tokenizers:
            tokenizer_cfg = self.user_config.tokenizer
            # A configured tokenizer name takes precedence over the model name.
            self.tokenizers[model] = Tokenizer.from_pretrained(
                tokenizer_cfg.name or model,
                trust_remote_code=tokenizer_cfg.trust_remote_code,
                revision=tokenizer_cfg.revision,
            )
        return self.tokenizers[model]

aiperf.records.records_manager

RecordsManager

Bases: PullClientMixin, BaseComponentService

The RecordsManager service is primarily responsible for holding the results returned from the workers.

Source code in aiperf/records/records_manager.py
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
@implements_protocol(ServiceProtocol)
@ServiceFactory.register(ServiceType.RECORDS_MANAGER)
class RecordsManager(PullClientMixin, BaseComponentService):
    """
    The RecordsManager service is primarily responsible for holding the
    results returned from the workers.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str | None = None,
    ) -> None:
        # Bind the pull socket: this service is the single sink for RECORDS
        # pushed by the (many) RecordProcessor instances.
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
            pull_client_address=CommAddress.RECORDS,
            pull_client_bind=True,
            pull_client_max_concurrency=DEFAULT_PULL_CLIENT_MAX_CONCURRENCY,
        )
        self._profile_cancelled = False

        # Profiling-phase bookkeeping, filled in by the message handlers below.
        self.start_time_ns: int | None = None
        self.processing_stats: PhaseProcessingStats = PhaseProcessingStats()
        self.final_request_count: int | None = None
        self.end_time_ns: int | None = None
        # Maps each distinct ErrorDetails to its occurrence count.
        self.error_summary: dict[ErrorDetails, int] = {}
        # Track per-worker statistics
        self.worker_stats: dict[str, PhaseProcessingStats] = {}

        # One results processor per registered type; every incoming result is
        # fanned out to all of them.
        self._results_processors: list[ResultsProcessorProtocol] = []
        for results_processor_type in ResultsProcessorFactory.get_all_class_types():
            results_processor = ResultsProcessorFactory.create_instance(
                class_type=results_processor_type,
                service_id=self.service_id,
                service_config=self.service_config,
                user_config=self.user_config,
            )
            self.debug(
                f"Created results processor: {results_processor_type}: {results_processor.__class__.__name__}"
            )
            self._results_processors.append(results_processor)

    @on_pull_message(MessageType.METRIC_RECORDS)
    async def _on_metric_records(self, message: MetricRecordsMessage) -> None:
        """Handle a metric records message.

        Updates global and per-worker stats, forwards the results to the
        results processors, and kicks off final processing once the expected
        number of records has been received.
        """
        self.trace(lambda: f"Received metric records: {message}")

        if message.credit_phase != CreditPhase.PROFILING:
            self.debug(lambda: f"Skipping non-profiling record: {message.credit_phase}")
            return

        worker_id = message.worker_id
        if worker_id not in self.worker_stats:
            self.worker_stats[worker_id] = PhaseProcessingStats()

        if message.valid:
            self.worker_stats[worker_id].processed += 1
            self.processing_stats.processed += 1
        else:
            self.worker_stats[worker_id].errors += 1
            self.processing_stats.errors += 1
            if message.error:
                self.error_summary[message.error] = (
                    self.error_summary.get(message.error, 0) + 1
                )

        await self._send_results_to_results_processors(message.results)

        # final_request_count stays None until the credit phase completes, so
        # this branch can only trigger after _on_credit_phase_complete ran.
        if (
            self.final_request_count is not None
            and self.processing_stats.total_records >= self.final_request_count
        ):
            self.info(
                lambda: f"Processed {self.processing_stats.processed} valid requests and {self.processing_stats.errors} errors ({self.processing_stats.total_records} total)."
            )
            # Make sure everyone knows the final stats, including the worker stats
            await self._publish_processing_stats()

            # Send a message to the event bus to signal that we received all the records
            await self.publish(
                AllRecordsReceivedMessage(
                    service_id=self.service_id,
                    request_ns=time.time_ns(),
                    final_processing_stats=self.processing_stats,
                )
            )

            self.debug(lambda: f"Received all records: {message}, processing now...")
            await self._process_records(cancelled=self._profile_cancelled)

    async def _send_results_to_results_processors(
        self, results: list[dict[MetricTagT, MetricValueTypeT]]
    ) -> None:
        """Send the results to each of the results processors."""
        # Fan out every result to every processor concurrently.
        await asyncio.gather(
            *[
                results_processor.process_result(result)
                for results_processor in self._results_processors
                for result in results
            ]
        )

    @on_message(MessageType.CREDIT_PHASE_START)
    async def _on_credit_phase_start(
        self, phase_start_msg: CreditPhaseStartMessage
    ) -> None:
        """Handle a credit phase start message in order to track the total number of expected requests."""
        if phase_start_msg.phase == CreditPhase.PROFILING:
            self.start_time_ns = phase_start_msg.start_ns or time.time_ns()
            self.processing_stats.total_expected_requests = (
                phase_start_msg.total_expected_requests
            )

    @on_message(MessageType.CREDIT_PHASE_COMPLETE)
    async def _on_credit_phase_complete(
        self, phase_complete_msg: CreditPhaseCompleteMessage
    ) -> None:
        """Handle a credit phase complete message in order to track the final request count."""
        if phase_complete_msg.phase != CreditPhase.PROFILING:
            return
        # This will equate to how many records we expect to receive,
        # and once we receive that many records, we know to stop.
        self.final_request_count = phase_complete_msg.completed
        self.end_time_ns = phase_complete_msg.end_ns or time.time_ns()
        self.info(f"Updating final request count to {self.final_request_count}")
        self.notice(
            f"All requests have completed, please wait for the results to be processed (currently {self.processing_stats.total_records} of {self.final_request_count} records processed)..."
        )
        # If all records already arrived before this message, process now;
        # otherwise _on_metric_records triggers processing on the last record.
        if self.final_request_count == self.processing_stats.total_records:
            await self._process_records(cancelled=False)

    @background_task(
        interval=lambda self: self.service_config.progress_report_interval,
        immediate=False,
    )
    async def _report_records_task(self) -> None:
        """Report the records processing stats."""
        if self.processing_stats.processed > 0 or self.processing_stats.errors > 0:
            # Only publish stats if there are records to report
            await self._publish_processing_stats()

    async def _publish_processing_stats(self) -> None:
        """Publish the profile processing stats."""
        await self.publish(
            RecordsProcessingStatsMessage(
                service_id=self.service_id,
                request_ns=time.time_ns(),
                processing_stats=self.processing_stats,
                worker_stats=self.worker_stats,
            ),
        )

    @on_command(CommandType.PROCESS_RECORDS)
    async def _on_process_records_command(
        self, message: ProcessRecordsCommand
    ) -> ProcessRecordsResult:
        """Handle the process records command by forwarding it to all of the results processors, and returning the results."""
        self.debug(lambda: f"Received process records command: {message}")
        return await self._process_records(cancelled=message.cancelled)

    @on_command(CommandType.PROFILE_CANCEL)
    async def _on_profile_cancel_command(
        self, message: ProfileCancelCommand
    ) -> ProcessRecordsResult:
        """Handle the profile cancel command by cancelling the streaming post processors."""
        self.debug(lambda: f"Received profile cancel command: {message}")
        self._profile_cancelled = True
        return await self._process_records(cancelled=True)

    async def _process_records(self, cancelled: bool) -> ProcessRecordsResult:
        """Process the records.

        Summarizes every results processor, splits successes from errors,
        publishes the combined ProcessRecordsResult, and returns it.
        """
        self.debug(lambda: f"Processing records (cancelled: {cancelled})")

        self.info("Processing records results...")
        # Process the records through the results processors.
        # return_exceptions=True so one failing processor cannot cancel the rest.
        results = await asyncio.gather(
            *[
                results_processor.summarize()
                for results_processor in self._results_processors
            ],
            return_exceptions=True,
        )

        records_results, error_results = [], []
        for result in results:
            if isinstance(result, list):
                records_results.extend(result)
            elif isinstance(result, ErrorDetails):
                error_results.append(result)
            elif isinstance(result, BaseException):
                error_results.append(ErrorDetails.from_exception(result))

        result = ProcessRecordsResult(
            results=ProfileResults(
                records=records_results,
                completed=len(records_results),
                # Fall back to "now" if phase start/end messages never arrived.
                start_ns=self.start_time_ns or time.time_ns(),
                end_ns=self.end_time_ns or time.time_ns(),
                error_summary=self.get_error_summary(),
                was_cancelled=cancelled,
            ),
            errors=error_results,
        )
        self.debug(lambda: f"Process records result: {result}")
        await self.publish(
            ProcessRecordsResultMessage(
                service_id=self.service_id,
                results=result,
            )
        )
        return result

    def get_error_summary(self) -> list[ErrorDetailsCount]:
        """Generate a summary of the error records."""
        return [
            ErrorDetailsCount(error_details=error_details, count=count)
            for error_details, count in self.error_summary.items()
        ]

get_error_summary()

Generate a summary of the error records.

Source code in aiperf/records/records_manager.py
260
261
262
263
264
265
def get_error_summary(self) -> list[ErrorDetailsCount]:
    """Generate a summary of the error records."""
    return [
        ErrorDetailsCount(error_details=error_details, count=count)
        for error_details, count in self.error_summary.items()
    ]

main()

Main entry point for the records manager.

Source code in aiperf/records/records_manager.py
268
269
270
271
272
273
def main() -> None:
    """Main entry point for the records manager.

    The bootstrap import is deferred so it is only paid when the records
    manager is actually launched as a service.
    """

    from aiperf.common.bootstrap import bootstrap_and_run_service

    bootstrap_and_run_service(RecordsManager)

aiperf.timing.concurrency_strategy

ConcurrencyStrategy

Bases: CreditIssuingStrategy

Class for concurrency credit issuing strategy.

Source code in aiperf/timing/concurrency_strategy.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
@CreditIssuingStrategyFactory.register(TimingMode.CONCURRENCY)
class ConcurrencyStrategy(CreditIssuingStrategy):
    """Class for concurrency credit issuing strategy.

    Caps the number of outstanding credits at ``config.concurrency`` by
    acquiring a semaphore slot per credit drop and releasing it on each
    credit return.
    """

    def __init__(
        self, config: TimingManagerConfig, credit_manager: CreditManagerProtocol
    ):
        super().__init__(config=config, credit_manager=credit_manager)

        # If the concurrency is larger than the total number of requests, it does not matter
        # as it is simply an upper bound that will never be reached
        self._semaphore = asyncio.Semaphore(value=config.concurrency)

    async def _execute_single_phase(self, phase_stats: CreditPhaseStats) -> None:
        """Execute a single credit phase. This will not return until the phase sending is complete."""
        if phase_stats.is_time_based:
            await self._execute_time_based_phase(phase_stats)
        elif phase_stats.is_request_count_based:
            await self._execute_request_count_based_phase(phase_stats)
        else:
            # Phase configs are validated up front, so reaching here is a bug.
            raise InvalidStateError(
                "Phase must have either a valid total or expected_duration_ns set"
            )

    async def _execute_time_based_phase(self, phase_stats: CreditPhaseStats) -> None:
        """Execute a time-based phase."""

        # Start the internal loop in a task so that we can cancel it when the time expires
        time_task = asyncio.create_task(
            self._execute_time_based_phase_internal(phase_stats)
        )

        # Calculate how long until the phase expires
        # (phase start converted to seconds + expected duration - now).
        sleep_time_sec = (
            (phase_stats.start_ns / NANOS_PER_SECOND)  # type: ignore
            + phase_stats.expected_duration_sec
            - time.time()
        )
        self.trace(
            lambda: f"Time-based phase will expire in {sleep_time_sec} seconds: {phase_stats}"
        )

        # Sleep until the phase expires, and then cancel the task
        await asyncio.sleep(sleep_time_sec)
        time_task.cancel()
        self.debug(lambda: f"Time-based phase execution expired: {phase_stats}")
        # Note, not awaiting the task here as we do not want to block moving to the next phase

    async def _execute_time_based_phase_internal(
        self, phase_stats: CreditPhaseStats
    ) -> None:
        """Execute the internal loop for a time-based phase. This will be called within a task and cancelled when the time expires."""

        self.trace(
            lambda: f"_execute_time_based_phase_internal loop entered: {phase_stats}"
        )

        # This will loop until the task is cancelled
        while True:
            try:
                # Acquire the semaphore. Once we hit the concurrency limit, this will block until a credit is returned
                await self._semaphore.acquire()
                self.execute_async(
                    self.credit_manager.drop_credit(
                        credit_phase=phase_stats.type,
                    )
                )
                phase_stats.sent += 1
            except asyncio.CancelledError:
                # Cancellation is the normal exit path for a time-based phase;
                # it is swallowed here deliberately so the task ends cleanly.
                self.trace(
                    lambda: f"_execute_time_based_phase_internal loop exited: {phase_stats}"
                )
                self.debug("Time-based phase execution expired")
                break

    async def _execute_request_count_based_phase(
        self, phase_stats: CreditPhaseStats
    ) -> None:
        """Drop exactly ``total_expected_requests`` credits, gated by the semaphore."""
        self.trace(
            lambda: f"_execute_request_count_based_phase loop entered: {phase_stats}"
        )

        total: int = phase_stats.total_expected_requests  # type: ignore

        while phase_stats.sent < total:
            await self._semaphore.acquire()
            self.execute_async(
                self.credit_manager.drop_credit(
                    credit_phase=phase_stats.type,
                )
            )
            phase_stats.sent += 1

        self.trace(
            lambda: f"_execute_request_count_based_phase loop exited: {phase_stats}"
        )

    async def _on_credit_return(self, message: CreditReturnMessage) -> None:
        """Process a credit return message."""

        # Release the semaphore to allow another credit to be issued,
        # then call the superclass to handle the credit return like normal
        self._semaphore.release()
        self.trace(lambda: f"Credit return released semaphore: {self._semaphore}")
        await super()._on_credit_return(message)

aiperf.timing.config

TimingManagerConfig

Bases: AIPerfBaseModel

Configuration for the timing manager.

Source code in aiperf/timing/config.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class TimingManagerConfig(AIPerfBaseModel):
    """Configuration for the timing manager."""

    timing_mode: TimingMode = LoadGeneratorDefaults.TIMING_MODE
    concurrency: int = LoadGeneratorDefaults.CONCURRENCY
    request_rate: float | None = LoadGeneratorDefaults.REQUEST_RATE
    request_rate_mode: RequestRateMode = LoadGeneratorDefaults.REQUEST_RATE_MODE
    request_count: int = LoadGeneratorDefaults.REQUEST_COUNT
    warmup_request_count: int = LoadGeneratorDefaults.WARMUP_REQUEST_COUNT
    random_seed: int | None = None
    progress_report_interval_sec: float = ServiceDefaults.PROGRESS_REPORT_INTERVAL

    @classmethod
    def from_user_config(cls, user_config: UserConfig) -> "TimingManagerConfig":
        """Create a TimingManagerConfig from a UserConfig."""

        # Precedence: a fixed schedule wins, then an explicit request rate,
        # and concurrency mode is the fallback when neither is provided.
        if user_config.input.fixed_schedule:
            mode = TimingMode.FIXED_SCHEDULE
        elif user_config.loadgen.request_rate is not None:
            mode = TimingMode.REQUEST_RATE
        else:
            mode = TimingMode.CONCURRENCY

        loadgen = user_config.loadgen
        return cls(
            timing_mode=mode,
            concurrency=loadgen.concurrency,
            request_rate=loadgen.request_rate,
            request_rate_mode=loadgen.request_rate_mode,
            request_count=loadgen.request_count,
            warmup_request_count=loadgen.warmup_request_count,
            random_seed=user_config.input.random_seed,
        )

from_user_config(user_config) classmethod

Create a TimingManagerConfig from a UserConfig.

Source code in aiperf/timing/config.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
@classmethod
def from_user_config(cls, user_config: UserConfig) -> "TimingManagerConfig":
    """Create a TimingManagerConfig from a UserConfig."""

    if user_config.input.fixed_schedule:
        timing_mode = TimingMode.FIXED_SCHEDULE
    elif user_config.loadgen.request_rate is not None:
        timing_mode = TimingMode.REQUEST_RATE
    else:
        # Default to concurrency mode if no request rate or schedule is provided
        timing_mode = TimingMode.CONCURRENCY

    return cls(
        timing_mode=timing_mode,
        concurrency=user_config.loadgen.concurrency,
        request_rate=user_config.loadgen.request_rate,
        request_rate_mode=user_config.loadgen.request_rate_mode,
        request_count=user_config.loadgen.request_count,
        warmup_request_count=user_config.loadgen.warmup_request_count,
        random_seed=user_config.input.random_seed,
    )

aiperf.timing.credit_issuing_strategy

CreditIssuingStrategy

Bases: TaskManagerMixin, ABC

Base class for credit issuing strategies.

Source code in aiperf/timing/credit_issuing_strategy.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
class CreditIssuingStrategy(TaskManagerMixin, ABC):
    """
    Base class for credit issuing strategies.

    Owns the ordered list of credit phases (warmup, profiling), tracks their
    running stats, publishes phase lifecycle messages via the credit manager,
    and reports progress on a fixed interval. Subclasses implement
    ``_execute_single_phase`` to decide *when* each credit is dropped.
    """

    def __init__(
        self, config: TimingManagerConfig, credit_manager: CreditManagerProtocol
    ):
        super().__init__()
        self.config = config
        self.credit_manager = credit_manager

        # This event is set when all phases are complete
        self.all_phases_complete_event = asyncio.Event()

        # The running stats for each phase, keyed by phase type.
        self.phase_stats: dict[CreditPhase, CreditPhaseStats] = {}

        # The phases to run including their configuration, in order of execution.
        self.ordered_phase_configs: list[CreditPhaseConfig] = []

        self._setup_phase_configs()
        self._validate_phase_configs()

    def _setup_phase_configs(self) -> None:
        """Setup the phases for the strategy. This can be overridden in subclasses to modify the phases."""
        # Warmup (optional) is appended before profiling, fixing execution order.
        self._setup_warmup_phase_config()
        self._setup_profiling_phase_config()
        self.info(
            lambda: f"Credit issuing strategy {self.__class__.__name__} initialized with {len(self.ordered_phase_configs)} "
            f"phase(s): {self.ordered_phase_configs}"
        )

    def _setup_warmup_phase_config(self) -> None:
        """Setup the warmup phase. This can be overridden in subclasses to modify the warmup phase."""
        # A warmup phase only exists when a positive warmup request count is configured.
        if self.config.warmup_request_count > 0:
            self.ordered_phase_configs.append(
                CreditPhaseConfig(
                    type=CreditPhase.WARMUP,
                    total_expected_requests=self.config.warmup_request_count,
                )
            )

    def _setup_profiling_phase_config(self) -> None:
        """Setup the profiling phase. This can be overridden in subclasses to modify the profiling phase."""
        self.ordered_phase_configs.append(
            CreditPhaseConfig(
                type=CreditPhase.PROFILING,
                total_expected_requests=self.config.request_count,
            )
        )

    def _validate_phase_configs(self) -> None:
        """Validate the phase configs.

        Raises:
            ConfigurationError: If any phase has neither a request total nor a duration.
        """
        for phase_config in self.ordered_phase_configs:
            if not phase_config.is_valid:
                raise ConfigurationError(
                    f"Phase {phase_config.type} is not valid. It must have either a valid total_expected_requests or expected_duration_sec set"
                )

    async def start(self) -> None:
        """Start the credit issuing strategy. This will launch the progress reporting loop, the
        warmup phase (if applicable), and the profiling phase, all in the background.

        Although the work runs in background tasks, this coroutine does not
        return until ``all_phases_complete_event`` is set."""
        self.debug(
            lambda: f"Starting credit issuing strategy {self.__class__.__name__}"
        )
        self.all_phases_complete_event.clear()

        # Start the progress reporting loop in the background
        self.execute_async(self._progress_report_loop())

        # Execute the phases in the background
        self.execute_async(self._execute_phases())

        self.debug(
            lambda: f"Waiting for all credit phases to complete for {self.__class__.__name__}"
        )
        # Wait for all phases to complete before returning
        await self.all_phases_complete_event.wait()
        self.debug(lambda: f"All credit phases completed for {self.__class__.__name__}")

    async def _execute_phases(self) -> None:
        """Execute the all of the credit phases sequentially. This can be overridden in subclasses to modify the execution of the phases."""
        for phase_config in self.ordered_phase_configs:
            phase_stats = CreditPhaseStats.from_phase_config(phase_config)
            phase_stats.start_ns = time.time_ns()
            self.phase_stats[phase_config.type] = phase_stats

            self.execute_async(
                self.credit_manager.publish_phase_start(
                    phase_config.type,
                    phase_stats.start_ns,
                    # Only one of the below will be set, this is already validated in the strategy
                    phase_config.total_expected_requests,
                    phase_config.expected_duration_sec,
                )
            )

            # This is implemented in subclasses
            await self._execute_single_phase(phase_stats)

            # We have sent all the credits for this phase. We must continue to the next
            # phase even though not all the credits have been returned. This is because
            # we do not want a gap in the credit issuing.
            phase_stats.sent_end_ns = time.time_ns()
            self.execute_async(
                self.credit_manager.publish_phase_sending_complete(
                    phase_config.type, phase_stats.sent_end_ns
                )
            )

    @abstractmethod
    async def _execute_single_phase(self, phase_stats: CreditPhaseStats) -> None:
        """Execute a single phase. Should not return until the phase sending is complete. Must be implemented in subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    async def stop(self) -> None:
        """Stop the credit issuing strategy."""
        await self.cancel_all_tasks()

    async def _on_credit_return(self, message: CreditReturnMessage) -> None:
        """This is called by the credit manager when a credit is returned. It can be
        overridden in subclasses to handle the credit return."""
        if message.phase not in self.phase_stats:
            # Stats are popped once a phase completes, so late returns for a
            # finished phase are ignored here.
            # self.warning(
            #     lambda: f"Credit return message received for phase {message.phase} but no phase stats found"
            # )
            return

        phase_stats = self.phase_stats[message.phase]
        phase_stats.completed += 1

        if (
            # If we have sent all the credits, check if this is the last one to be returned
            phase_stats.is_sending_complete
            # NOTE(review): total_expected_requests may be None for time-based
            # phases — presumably is_sending_complete guarantees a total by this
            # point; confirm (the type: ignore suggests this is assumed).
            and phase_stats.completed >= phase_stats.total_expected_requests  # type: ignore[operator]
        ):
            phase_stats.end_ns = time.time_ns()
            self.info(lambda: f"Phase completed: {phase_stats}")

            self.execute_async(
                self.credit_manager.publish_phase_complete(
                    message.phase, phase_stats.completed, phase_stats.end_ns
                )
            )

            if phase_stats.type == CreditPhase.PROFILING:
                # Profiling is the final phase: signal overall completion.
                self.execute_async(self.credit_manager.publish_credits_complete())
                self.all_phases_complete_event.set()

            # We don't need to keep track of the phase stats anymore
            self.notice(
                lambda: f"Phase {message.phase} completed, removing phase stats"
            )
            self.phase_stats.pop(message.phase)

    async def _progress_report_loop(self) -> None:
        """Report the progress at a fixed interval."""
        self.debug("Starting progress reporting loop")
        while not self.all_phases_complete_event.is_set():
            await asyncio.sleep(self.config.progress_report_interval_sec)

            for phase, stats in self.phase_stats.items():
                try:
                    await self.credit_manager.publish_progress(
                        phase, stats.sent, stats.completed
                    )
                except Exception as e:
                    self.error(f"Error publishing credit progress: {e}")
                # CancelledError derives from BaseException (Python 3.8+), so
                # the broad `except Exception` above does not swallow it.
                except asyncio.CancelledError:
                    self.debug("Credit progress reporting loop cancelled")
                    return

        self.debug("All credits completed, stopping credit progress reporting loop")

start() async

Start the credit issuing strategy. This will launch the progress reporting loop, the warmup phase (if applicable), and the profiling phase, all in the background.

Source code in aiperf/timing/credit_issuing_strategy.py
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
async def start(self) -> None:
    """Start the credit issuing strategy. This will launch the progress reporting loop, the
    warmup phase (if applicable), and the profiling phase, all in the background."""
    self.debug(
        lambda: f"Starting credit issuing strategy {self.__class__.__name__}"
    )
    self.all_phases_complete_event.clear()

    # Start the progress reporting loop in the background
    self.execute_async(self._progress_report_loop())

    # Execute the phases in the background
    self.execute_async(self._execute_phases())

    self.debug(
        lambda: f"Waiting for all credit phases to complete for {self.__class__.__name__}"
    )
    # Wait for all phases to complete before returning
    await self.all_phases_complete_event.wait()
    self.debug(lambda: f"All credit phases completed for {self.__class__.__name__}")

stop() async

Stop the credit issuing strategy.

Source code in aiperf/timing/credit_issuing_strategy.py
134
135
136
async def stop(self) -> None:
    """Stop the credit issuing strategy."""
    await self.cancel_all_tasks()

CreditIssuingStrategyFactory

Bases: AIPerfFactory[TimingMode, CreditIssuingStrategy]

Factory for creating credit issuing strategies based on the timing mode.

Source code in aiperf/timing/credit_issuing_strategy.py
194
195
class CreditIssuingStrategyFactory(AIPerfFactory[TimingMode, CreditIssuingStrategy]):
    """Factory for creating credit issuing strategies based on the timing mode.

    Strategies self-register with ``@CreditIssuingStrategyFactory.register(TimingMode.X)``.
    """

aiperf.timing.credit_manager

CreditManagerProtocol

Bases: PubClientProtocol, Protocol

Defines the interface for a CreditManager.

This is used to allow the credit issuing strategy to interact with the TimingManager in a decoupled way.

Source code in aiperf/timing/credit_manager.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@runtime_checkable
class CreditManagerProtocol(PubClientProtocol, Protocol):
    """Defines the interface for a CreditManager.

    This is used to allow the credit issuing strategy to interact with the TimingManager
    in a decoupled way.
    """

    async def drop_credit(
        self,
        credit_phase: CreditPhase,
        conversation_id: str | None = None,
        credit_drop_ns: int | None = None,
    ) -> None:
        """Issue a single credit for the given phase."""
        ...

    async def publish_progress(
        self, phase: CreditPhase, sent: int, completed: int
    ) -> None:
        """Publish sent/completed progress counts for a phase."""
        ...

    async def publish_credits_complete(self) -> None:
        """Publish that all credits have been issued and returned."""
        ...

    async def publish_phase_start(
        self,
        phase: CreditPhase,
        start_ns: int,
        total_expected_requests: int | None,
        expected_duration_sec: float | None,
    ) -> None:
        """Publish that a phase has started."""
        ...

    async def publish_phase_sending_complete(
        self, phase: CreditPhase, sent_end_ns: int
    ) -> None:
        """Publish that all credits for a phase have been sent."""
        ...

    async def publish_phase_complete(
        self, phase: CreditPhase, completed: int, end_ns: int
    ) -> None:
        """Publish that a phase has fully completed."""
        ...

CreditPhaseMessagesMixin

Bases: MessageBusClientMixin, CreditPhaseMessagesRequirements

Mixin for services to implement the CreditManagerProtocol.

Requirements

This mixin must be used with a class that provides: `pub_client: PubClientProtocol` and `service_id: str`.

Source code in aiperf/timing/credit_manager.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
class CreditPhaseMessagesMixin(MessageBusClientMixin, CreditPhaseMessagesRequirements):
    """Mixin for services to implement the CreditManagerProtocol.

    Each ``publish_*`` method builds the corresponding credit-phase message
    and schedules its publication asynchronously.

    Requirements:
        This mixin must be used with a class that provides:
        - pub_client: PubClientProtocol
        - service_id: str
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Fail fast if the host class does not satisfy the requirements protocol.
        if not isinstance(self, CreditPhaseMessagesRequirements):
            raise TypeError(
                "CreditPhaseMessagesMixin must be used with a class that provides CreditPhaseMessagesRequirements"
            )

    async def publish_phase_start(
        self,
        phase: CreditPhase,
        start_ns: int,
        total_expected_requests: int | None,
        expected_duration_sec: float | None,
    ) -> None:
        """Publish the phase start message."""
        # Only one of total_expected_requests / expected_duration_sec will be
        # set; this is already validated in the strategy.
        message = CreditPhaseStartMessage(
            service_id=self.service_id,
            phase=phase,
            start_ns=start_ns,
            total_expected_requests=total_expected_requests,
            expected_duration_sec=expected_duration_sec,
        )
        self.execute_async(self.publish(message))

    async def publish_phase_sending_complete(
        self, phase: CreditPhase, sent_end_ns: int
    ) -> None:
        """Publish the phase sending complete message."""
        message = CreditPhaseSendingCompleteMessage(
            service_id=self.service_id,
            phase=phase,
            sent_end_ns=sent_end_ns,
        )
        self.execute_async(self.publish(message))

    async def publish_phase_complete(
        self, phase: CreditPhase, completed: int, end_ns: int
    ) -> None:
        """Publish the phase complete message."""
        message = CreditPhaseCompleteMessage(
            service_id=self.service_id,
            phase=phase,
            completed=completed,
            end_ns=end_ns,
        )
        self.execute_async(self.publish(message))

    async def publish_progress(
        self, phase: CreditPhase, sent: int, completed: int
    ) -> None:
        """Publish the progress message."""
        message = CreditPhaseProgressMessage(
            service_id=self.service_id,
            phase=phase,
            sent=sent,
            completed=completed,
        )
        self.execute_async(self.publish(message))

    async def publish_credits_complete(self) -> None:
        """Publish the credits complete message."""
        self.debug("Publishing credits complete message")
        message = CreditsCompleteMessage(service_id=self.service_id)
        self.execute_async(self.publish(message))

publish_credits_complete() async

Publish the credits complete message.

Source code in aiperf/timing/credit_manager.py
146
147
148
149
150
151
async def publish_credits_complete(self) -> None:
    """Publish the credits complete message."""
    self.debug("Publishing credits complete message")
    self.execute_async(
        self.publish(CreditsCompleteMessage(service_id=self.service_id))
    )

publish_phase_complete(phase, completed, end_ns) async

Publish the phase complete message.

Source code in aiperf/timing/credit_manager.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
async def publish_phase_complete(
    self, phase: CreditPhase, completed: int, end_ns: int
) -> None:
    """Publish the phase complete message."""
    self.execute_async(
        self.publish(
            CreditPhaseCompleteMessage(
                service_id=self.service_id,
                phase=phase,
                completed=completed,
                end_ns=end_ns,
            )
        )
    )

publish_phase_sending_complete(phase, sent_end_ns) async

Publish the phase sending complete message.

Source code in aiperf/timing/credit_manager.py
102
103
104
105
106
107
108
109
110
111
112
113
114
async def publish_phase_sending_complete(
    self, phase: CreditPhase, sent_end_ns: int
) -> None:
    """Publish the phase sending complete message."""
    self.execute_async(
        self.publish(
            CreditPhaseSendingCompleteMessage(
                service_id=self.service_id,
                phase=phase,
                sent_end_ns=sent_end_ns,
            )
        )
    )

publish_phase_start(phase, start_ns, total_expected_requests, expected_duration_sec) async

Publish the phase start message.

Source code in aiperf/timing/credit_manager.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
async def publish_phase_start(
    self,
    phase: CreditPhase,
    start_ns: int,
    total_expected_requests: int | None,
    expected_duration_sec: float | None,
) -> None:
    """Publish the phase start message."""
    self.execute_async(
        self.publish(
            CreditPhaseStartMessage(
                service_id=self.service_id,
                phase=phase,
                start_ns=start_ns,
                # Only one of the below will be set, this is already validated in the strategy
                total_expected_requests=total_expected_requests,
                expected_duration_sec=expected_duration_sec,
            )
        )
    )

publish_progress(phase, sent, completed) async

Publish the progress message.

Source code in aiperf/timing/credit_manager.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
async def publish_progress(
    self, phase: CreditPhase, sent: int, completed: int
) -> None:
    """Publish the progress message."""
    self.execute_async(
        self.publish(
            CreditPhaseProgressMessage(
                service_id=self.service_id,
                phase=phase,
                sent=sent,
                completed=completed,
            )
        )
    )

CreditPhaseMessagesRequirements

Bases: AIPerfLoggerProtocol, Protocol

Requirements for the CreditPhaseMessagesMixin. This is the list of attributes that must be provided by the class that uses this mixin.

Source code in aiperf/timing/credit_manager.py
57
58
59
60
61
62
@runtime_checkable
class CreditPhaseMessagesRequirements(AIPerfLoggerProtocol, Protocol):
    """Requirements for the CreditPhaseMessagesMixin. This is the list of attributes that must
    be provided by the class that uses this mixin."""

    # Unique identifier of the host service; stamped on every published message.
    service_id: str

aiperf.timing.fixed_schedule_strategy

FixedScheduleStrategy

Bases: CreditIssuingStrategy

Class for fixed schedule credit issuing strategy.

Source code in aiperf/timing/fixed_schedule_strategy.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
@CreditIssuingStrategyFactory.register(TimingMode.FIXED_SCHEDULE)
class FixedScheduleStrategy(CreditIssuingStrategy):
    """
    Class for fixed schedule credit issuing strategy.

    Replays a pre-built schedule of (timestamp, conversation_id) entries,
    where each timestamp is a nanosecond offset relative to the start of
    execution.
    """

    def __init__(
        self,
        config: TimingManagerConfig,
        credit_manager: CreditManagerProtocol,
        schedule: list[tuple[int, str]],
    ):
        super().__init__(config=config, credit_manager=credit_manager)

        # (timestamp_ns_offset, conversation_id) pairs to replay.
        self._schedule: list[tuple[int, str]] = schedule

    async def _execute_single_phase(self, phase_stats: CreditPhaseStats) -> None:
        """Replay the schedule, dropping all credits that share a timestamp together."""
        # TODO: Convert this code to work with the new CreditPhase logic and base classes

        if not self._schedule:
            self.warning("No schedule loaded, no credits will be dropped")
            return

        start_time_ns = time.time_ns()

        # Group entries by timestamp so simultaneous drops fire in one burst.
        timestamp_groups = defaultdict(list)

        for timestamp, conversation_id in self._schedule:
            timestamp_groups[timestamp].append((timestamp, conversation_id))

        schedule_unique_sorted = sorted(timestamp_groups.keys())

        for unique_timestamp in schedule_unique_sorted:
            # Remaining wait is relative to the original start time, so per-drop
            # overhead does not accumulate drift across the schedule.
            wait_duration_ns = max(0, start_time_ns + unique_timestamp - time.time_ns())
            wait_duration_sec = wait_duration_ns / 1_000_000_000

            if wait_duration_sec > 0:
                await asyncio.sleep(wait_duration_sec)

            for _, conversation_id in timestamp_groups[unique_timestamp]:
                self.execute_async(
                    self.credit_manager.drop_credit(
                        credit_phase=CreditPhase.PROFILING,
                        conversation_id=conversation_id,
                        # We already waited, so it can be sent ASAP
                        credit_drop_ns=None,
                    )
                )

        self.info("Completed all scheduled credit drops")

aiperf.timing.request_rate_strategy

RequestRateStrategy

Bases: CreditIssuingStrategy

Strategy for issuing credits based on a specified request rate.

Supports two modes: CONSTANT, which issues credits at a constant rate with fixed intervals, and POISSON, which issues credits using a Poisson process with exponentially distributed intervals.

Source code in aiperf/timing/request_rate_strategy.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
@CreditIssuingStrategyFactory.register(TimingMode.REQUEST_RATE)
class RequestRateStrategy(CreditIssuingStrategy):
    """
    Strategy for issuing credits based on a specified request rate.

    Supports two modes:
    - CONSTANT: Issues credits at a constant rate with fixed intervals
    - POISSON: Issues credits using a Poisson process with exponentially distributed intervals
    """

    def __init__(
        self, config: TimingManagerConfig, credit_manager: CreditManagerProtocol
    ):
        """Validate the timing config and set up the rate parameters.

        Raises:
            InvalidStateError: If the request rate is unset or the request count is < 1.
        """
        super().__init__(config=config, credit_manager=credit_manager)

        if config.request_rate is None:
            raise InvalidStateError("Request rate is not set")
        if config.request_count < 1:
            raise InvalidStateError("Request count must be at least 1")

        self._request_rate = config.request_rate
        self._request_rate_mode = config.request_rate_mode

        # Initialize random number generator for reproducibility.
        # BUGFIX: compare against None explicitly -- a seed of 0 is falsy but is a
        # perfectly valid seed, and the previous truthiness check silently ignored
        # it, producing a non-reproducible sequence.
        self._random = (
            random.Random(config.random_seed)
            if config.random_seed is not None
            else random.Random()
        )

    async def _execute_single_phase(self, phase_stats: CreditPhaseStats) -> None:
        """Execute a single phase. This will not return until the phase sending is complete.

        Raises:
            InvalidStateError: If the configured request rate mode is unsupported.
        """
        # Issue credit drops at the specified rate
        if self._request_rate_mode == RequestRateMode.CONSTANT:
            await self._execute_constant_rate(phase_stats)
        elif self._request_rate_mode == RequestRateMode.POISSON:
            await self._execute_poisson_rate(phase_stats)
        else:
            raise InvalidStateError(
                f"Unsupported request rate mode: {self._request_rate_mode}"
            )

    async def _execute_constant_rate(self, phase_stats: CreditPhaseStats) -> None:
        """Execute credit drops at a constant rate (one drop every 1/rate seconds)."""

        # The effective time between each credit drop is the inverse of the request rate.
        period_sec = 1.0 / self._request_rate

        # We start by sending the first credit immediately.
        next_drop_at = time.perf_counter()

        while phase_stats.should_send:
            wait_sec = next_drop_at - time.perf_counter()
            if wait_sec > 0:
                await asyncio.sleep(wait_sec)

            # Fire-and-forget the drop so a slow consumer cannot stall the schedule.
            self.execute_async(
                self.credit_manager.drop_credit(credit_phase=phase_stats.type)
            )
            phase_stats.sent += 1

            # Instead of naively sleeping for a constant period_sec, we are scheduling the
            # next drop to happen at exactly (next_drop_at + period_sec). This ensures that
            # we do not slowly drift over time based on slight variances in the asyncio.sleep
            # or executing the credit drop task.
            next_drop_at += period_sec

    async def _execute_poisson_rate(self, phase_stats: CreditPhaseStats) -> None:
        """Execute credit drops using Poisson process (exponential inter-arrival times).

        In a Poisson process with rate λ (requests per second), the inter-arrival times
        are exponentially distributed with parameter λ. This models realistic traffic
        patterns where requests arrive randomly but at a consistent average rate.
        """
        while phase_stats.should_send:
            # For Poisson process, inter-arrival times are exponentially distributed.
            # random.expovariate(lambd) generates exponentially distributed random numbers
            # where lambd is the rate parameter (requests per second)
            wait_duration_sec = self._random.expovariate(self._request_rate)

            if wait_duration_sec > 0:
                await asyncio.sleep(wait_duration_sec)

            self.execute_async(
                self.credit_manager.drop_credit(credit_phase=phase_stats.type)
            )
            phase_stats.sent += 1

aiperf.timing.timing_manager

TimingManager

Bases: PullClientMixin, BaseComponentService, CreditPhaseMessagesMixin

The TimingManager service is responsible for generating the schedule and issuing timing credits for requests.

Source code in aiperf/timing/timing_manager.py
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
@implements_protocol(ServiceProtocol)
@ServiceFactory.register(ServiceType.TIMING_MANAGER)
class TimingManager(PullClientMixin, BaseComponentService, CreditPhaseMessagesMixin):
    """
    The TimingManager service is responsible for generating the schedule and issuing
    timing credits for requests.

    It pulls credit returns from workers on the CREDIT_RETURN address and pushes
    credit drops to workers on the CREDIT_DROP address; the actual issuing policy is
    delegated to a CreditIssuingStrategy chosen during PROFILE_CONFIGURE.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str | None = None,
    ) -> None:
        """Set up the comms clients and derive the timing config from the user config."""
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
            pull_client_address=CommAddress.CREDIT_RETURN,
            pull_client_bind=True,
        )
        self.debug("Timing manager __init__")
        self.config = TimingManagerConfig.from_user_config(self.user_config)

        # Request/reply client used to fetch dataset timing data (fixed-schedule mode).
        self.dataset_request_client: RequestClientProtocol = (
            self.comms.create_request_client(
                CommAddress.DATASET_MANAGER_PROXY_FRONTEND,
            )
        )
        # Push client over which credit drops are sent out to the workers.
        self.credit_drop_push_client: PushClientProtocol = (
            self.comms.create_push_client(
                CommAddress.CREDIT_DROP,
                bind=True,
            )
        )

        # Selected during the PROFILE_CONFIGURE command; None until then.
        self._credit_issuing_strategy: CreditIssuingStrategy | None = None

    @on_command(CommandType.PROFILE_CONFIGURE)
    async def _profile_configure_command(
        self, message: ProfileConfigureCommand
    ) -> None:
        """Configure the timing manager.

        Instantiates the credit issuing strategy matching the configured timing mode
        (FIXED_SCHEDULE, CONCURRENCY, or REQUEST_RATE).

        Raises:
            InvalidStateError: If no strategy was created (unknown timing mode).
        """
        self.debug(f"Configuring credit issuing strategy for {self.service_id}")

        if self.config.timing_mode == TimingMode.FIXED_SCHEDULE:
            # This will block until the dataset is ready and the timing response is received
            dataset_timing_response: DatasetTimingResponse = (
                await self.dataset_request_client.request(
                    message=DatasetTimingRequest(
                        service_id=self.service_id,
                    ),
                )
            )
            self.debug(
                lambda: f"TM: Received dataset timing response: {dataset_timing_response}"
            )
            self.info("Using fixed schedule strategy")
            self._credit_issuing_strategy = (
                CreditIssuingStrategyFactory.create_instance(
                    TimingMode.FIXED_SCHEDULE,
                    config=self.config,
                    credit_manager=self,
                    schedule=dataset_timing_response.timing_data,
                )
            )
        elif self.config.timing_mode == TimingMode.CONCURRENCY:
            self.info("Using concurrency strategy")
            self._credit_issuing_strategy = (
                CreditIssuingStrategyFactory.create_instance(
                    TimingMode.CONCURRENCY,
                    config=self.config,
                    credit_manager=self,
                )
            )
        elif self.config.timing_mode == TimingMode.REQUEST_RATE:
            self.info("Using request rate strategy")
            self._credit_issuing_strategy = (
                CreditIssuingStrategyFactory.create_instance(
                    TimingMode.REQUEST_RATE,
                    config=self.config,
                    credit_manager=self,
                )
            )

        # Guard against an unhandled timing mode falling through the chain above.
        if not self._credit_issuing_strategy:
            raise InvalidStateError("No credit issuing strategy configured")
        self.debug(
            lambda: f"Timing manager configured with credit issuing strategy: {self._credit_issuing_strategy}"
        )

    @on_command(CommandType.PROFILE_START)
    async def _on_start_profiling(self, message: CommandMessage) -> None:
        """Start the timing manager and issue credit drops according to the configured strategy.

        Raises:
            InvalidStateError: If PROFILE_CONFIGURE has not set up a strategy yet.
        """
        self.debug("Starting profiling")

        self.debug("Waiting for timing manager to be initialized")
        await self.initialized_event.wait()
        self.debug("Timing manager initialized, starting profiling")

        if not self._credit_issuing_strategy:
            raise InvalidStateError("No credit issuing strategy configured")

        # Run the strategy in the background so this command handler returns promptly.
        self.execute_async(self._credit_issuing_strategy.start())
        self.info(
            f"Credit issuing strategy for {self.config.timing_mode.title()} started"
        )

    @on_command(CommandType.PROFILE_CANCEL)
    async def _handle_profile_cancel_command(
        self, message: ProfileCancelCommand
    ) -> None:
        """Acknowledge the cancel command and stop the active strategy, if any."""
        self.debug(lambda: f"Received profile cancel command: {message}")
        await self.publish(
            CommandAcknowledgedResponse.from_command_message(message, self.service_id)
        )
        if self._credit_issuing_strategy:
            await self._credit_issuing_strategy.stop()

    @on_stop
    async def _timing_manager_stop(self) -> None:
        """Stop the timing manager."""
        self.debug("Stopping timing manager")
        if self._credit_issuing_strategy:
            await self._credit_issuing_strategy.stop()
        await self.cancel_all_tasks()

    @on_pull_message(MessageType.CREDIT_RETURN)
    async def _on_credit_return(self, message: CreditReturnMessage) -> None:
        """Handle the credit return message.

        Forwards the return to the active strategy so it can track outstanding credits.
        NOTE(review): this reaches into the strategy's private `_on_credit_return`
        hook -- consider promoting it to a public method on the strategy interface.
        """
        self.debug(lambda: f"Timing manager received credit return message: {message}")
        if self._credit_issuing_strategy:
            await self._credit_issuing_strategy._on_credit_return(message)

    async def drop_credit(
        self,
        credit_phase: CreditPhase,
        conversation_id: str | None = None,
        credit_drop_ns: int | None = None,
    ) -> None:
        """Drop a credit.

        Schedules (fire-and-forget) a CreditDropMessage push on the credit-drop channel.

        Args:
            credit_phase: The phase (e.g. profiling) the credit belongs to.
            conversation_id: Optional conversation to pin the credit to.
            credit_drop_ns: Optional wall-clock time (ns) the drop was scheduled for;
                None means "send as soon as possible".
        """
        self.execute_async(
            self.credit_drop_push_client.push(
                message=CreditDropMessage(
                    service_id=self.service_id,
                    phase=credit_phase,
                    credit_drop_ns=credit_drop_ns,
                    conversation_id=conversation_id,
                ),
            )
        )

drop_credit(credit_phase, conversation_id=None, credit_drop_ns=None) async

Drop a credit.

Source code in aiperf/timing/timing_manager.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
async def drop_credit(
    self,
    credit_phase: CreditPhase,
    conversation_id: str | None = None,
    credit_drop_ns: int | None = None,
) -> None:
    """Drop a credit.

    Builds a CreditDropMessage for the given phase and fire-and-forgets it onto
    the credit-drop push channel via ``execute_async``.
    """
    drop_message = CreditDropMessage(
        service_id=self.service_id,
        phase=credit_phase,
        credit_drop_ns=credit_drop_ns,
        conversation_id=conversation_id,
    )
    self.execute_async(self.credit_drop_push_client.push(message=drop_message))

main()

Main entry point for the timing manager.

Source code in aiperf/timing/timing_manager.py
204
205
206
207
208
def main() -> None:
    """Main entry point for the timing manager.

    The bootstrap helper is imported locally (presumably to keep module import
    free of bootstrap side effects -- confirm) and then runs the TimingManager.
    """
    from aiperf.common.bootstrap import bootstrap_and_run_service

    bootstrap_and_run_service(TimingManager)

aiperf.workers.credit_processor_mixin

CreditProcessorMixin

Bases: CreditProcessorMixinRequirements

CreditProcessorMixin is a mixin that provides a method to process credit drops.

Source code in aiperf/workers/credit_processor_mixin.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
class CreditProcessorMixin(CreditProcessorMixinRequirements):
    """CreditProcessorMixin is a mixin that provides a method to process credit drops.

    Hosts must satisfy CreditProcessorMixinRequirements (comms clients, model
    endpoint, task stats); this mixin supplies the credit-drop processing pipeline:
    fetch conversation -> call inference API -> push the RequestRecord -> return credit.
    """

    def __init__(self, **kwargs):
        # Fail fast if the host class does not provide the members this mixin relies on.
        # NOTE(review): this __init__ never calls super().__init__(**kwargs), so any
        # cooperative __init__ later in the MRO would be skipped -- confirm this is
        # intentional for the classes that mix this in (e.g. Worker).
        if not isinstance(self, CreditProcessorMixinRequirements):
            raise ValueError(
                "CreditProcessorMixin must be used with CreditProcessorMixinRequirements"
            )

    async def _process_credit_drop_internal(
        self, message: CreditDropMessage
    ) -> CreditReturnMessage:
        """Process a credit drop message. Return the CreditReturnMessage.

        - Every credit must be returned after processing
        - All results or errors should be converted to a RequestRecord and pushed to the inference results client.

        NOTE: This function MUST NOT return until the credit drop is fully processed.
        This is to ensure that the max concurrency is respected via the semaphore of the pull client.
        The way this is enforced is by requiring that this method returns a CreditReturnMessage.
        """
        # TODO: Add tests to ensure that the above note is never violated in the future

        self.trace(lambda: f"Processing credit drop: {message}")
        drop_perf_ns = time.perf_counter_ns()  # The time the credit was received

        # Lazily create the per-phase stats bucket, then count this credit as started.
        if message.phase not in self.task_stats:
            self.task_stats[message.phase] = WorkerPhaseTaskStats()
        self.task_stats[message.phase].total += 1

        # Default empty record so the finally block always has something to publish.
        record: RequestRecord = RequestRecord()
        try:
            record = await self._execute_single_credit_internal(message)

        except Exception as e:
            # Convert any processing failure into an error record rather than propagating.
            self.exception(f"Error processing credit drop: {e}")
            record.error = ErrorDetails.from_exception(e)
            record.end_perf_ns = time.perf_counter_ns()

        finally:
            record.credit_phase = message.phase
            msg = InferenceResultsMessage(
                service_id=self.service_id,
                record=record,
            )

            # Note that we already ensured that the phase exists in the task_stats dict in the above code.
            if not record.valid:
                self.task_stats[message.phase].failed += 1
            else:
                self.task_stats[message.phase].completed += 1

            try:
                await self.inference_results_push_client.push(msg)
            except Exception as e:
                # If we fail to push the record, log the error and continue
                self.exception(f"Error pushing request record: {e}")
            finally:
                # Calculate the latency of the credit drop (from when the credit was dropped to when the request was sent)
                pre_inference_ns = record.start_perf_ns - drop_perf_ns
                # Always return the credits
                return_message = CreditReturnMessage(
                    service_id=self.service_id,
                    delayed_ns=record.delayed_ns,
                    pre_inference_ns=pre_inference_ns,
                    phase=message.phase,
                )
                self.trace(lambda: f"Returning credit {return_message}")
                # Returning from inside `finally` (B012) is deliberate: it guarantees
                # a CreditReturnMessage is produced even if pushing the record failed.
                return return_message  # noqa: B012

    async def _execute_single_credit_internal(
        self, message: CreditDropMessage
    ) -> RequestRecord:
        """Run a credit task for a single credit.

        Fetches the conversation for the credit's conversation_id from the dataset
        manager, then calls the inference API for its first turn.

        Raises:
            NotInitializedError: If the inference client was never set up.
        """

        if not self.inference_client:
            raise NotInitializedError("Inference server client not initialized.")

        # retrieve the prompt from the dataset
        conversation_response: ConversationResponseMessage = (
            await self.conversation_request_client.request(
                ConversationRequestMessage(
                    service_id=self.service_id,
                    conversation_id=message.conversation_id,
                    credit_phase=message.phase,
                )
            )
        )
        self.trace(lambda: f"Received response message: {conversation_response}")

        # Dataset lookup failures come back as an ErrorMessage; surface them as an
        # error record instead of raising.
        if isinstance(conversation_response, ErrorMessage):
            return RequestRecord(
                model_name=self.model_endpoint.primary_model_name,
                conversation_id=message.conversation_id,
                turn_index=0,
                timestamp_ns=time.time_ns(),
                start_perf_ns=time.perf_counter_ns(),
                end_perf_ns=time.perf_counter_ns(),
                error=conversation_response.error,
            )

        # NOTE(review): only turns[0] is used and turn_index is hard-coded to 0 --
        # presumably multi-turn conversations are handled elsewhere or are future work.
        record = await self._call_inference_api_internal(
            message, conversation_response.conversation.turns[0]
        )
        record.model_name = self.model_endpoint.primary_model_name
        record.conversation_id = conversation_response.conversation.session_id
        record.turn_index = 0
        return record

    async def _call_inference_api_internal(
        self,
        message: CreditDropMessage,
        turn: Turn,
    ) -> RequestRecord:
        """Make a single call to the inference API. Will return an error record if the call fails."""
        self.trace(lambda: f"Calling inference API for turn: {turn}")
        # Pre-declare these so the except block can tell how far we got before failing.
        formatted_payload = None
        pre_send_perf_ns = None
        timestamp_ns = None
        try:
            # Format payload for the API request
            formatted_payload = await self.request_converter.format_payload(
                model_endpoint=self.model_endpoint,
                turn=turn,
            )

            # NOTE: Current implementation of the TimingManager bypasses this, it is for future use.
            # Wait for the credit drop time if it is in the future.
            # Note that we check this after we have retrieved the data from the dataset, to ensure
            # that we are fully ready to go.
            delayed_ns = None
            drop_ns = message.credit_drop_ns
            now_ns = time.time_ns()
            if drop_ns and drop_ns > now_ns:
                self.trace(
                    lambda: f"Waiting for credit drop expected time: {(drop_ns - now_ns) / NANOS_PER_SECOND:.2f} s"
                )
                await asyncio.sleep((drop_ns - now_ns) / NANOS_PER_SECOND)
            elif drop_ns and drop_ns < now_ns:
                # We missed the scheduled drop time; record how late we are.
                delayed_ns = now_ns - drop_ns

            # Save the current perf_ns before sending the request so it can be used to calculate
            # the start_perf_ns of the request in case of an exception.
            pre_send_perf_ns = time.perf_counter_ns()
            timestamp_ns = time.time_ns()

            # Send the request to the Inference Server API and wait for the response
            result: RequestRecord = await self.inference_client.send_request(
                model_endpoint=self.model_endpoint,
                payload=formatted_payload,
            )

            self.debug(
                lambda: f"pre_send_perf_ns to start_perf_ns latency: {result.start_perf_ns - pre_send_perf_ns} ns"
            )

            result.delayed_ns = delayed_ns
            return result

        except Exception as e:
            self.exception(
                f"Error calling inference server API at {self.model_endpoint.url}: {e}"
            )
            return RequestRecord(
                request=formatted_payload,
                timestamp_ns=timestamp_ns or time.time_ns(),
                # Try and use the pre_send_perf_ns if it is available, otherwise use the current time.
                start_perf_ns=pre_send_perf_ns or time.perf_counter_ns(),
                end_perf_ns=time.perf_counter_ns(),
                error=ErrorDetails.from_exception(e),
            )

CreditProcessorMixinRequirements

Bases: AIPerfLoggerProtocol, Protocol

CreditProcessorMixinRequirements is a protocol that provides the requirements needed for the CreditProcessorMixin.

Source code in aiperf/workers/credit_processor_mixin.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@runtime_checkable
class CreditProcessorMixinRequirements(AIPerfLoggerProtocol, Protocol):
    """CreditProcessorMixinRequirements is a protocol that provides the requirements needed for the CreditProcessorMixin.

    Any class mixing in CreditProcessorMixin must expose these attributes and methods.
    """

    # Identifier of the hosting service; stamped onto outgoing messages.
    service_id: str
    # Client used to send requests to the inference server.
    inference_client: InferenceClientProtocol
    # Request/reply client for fetching conversations from the dataset manager.
    conversation_request_client: RequestClientProtocol
    # Push client used to publish RequestRecord results.
    inference_results_push_client: PushClientProtocol
    # Converts a Turn into the payload format of the target endpoint.
    request_converter: RequestConverterProtocol
    # The model/endpoint this processor targets.
    model_endpoint: ModelEndpointInfo
    # Per-phase counters of total/completed/failed credit tasks.
    task_stats: dict[CreditPhase, WorkerPhaseTaskStats]

    async def _process_credit_drop_internal(
        self, message: CreditDropMessage
    ) -> CreditReturnMessage:
        """Process a credit drop message. Return the CreditReturnMessage."""
        ...

    async def _execute_single_credit_internal(
        self, message: CreditDropMessage
    ) -> RequestRecord:
        """Execute a single credit drop. Return the RequestRecord."""
        ...

    async def _call_inference_api_internal(
        self,
        message: CreditDropMessage,
        turn: Turn,
    ) -> RequestRecord:
        """Make a single call to the inference API. Will return an error record if the call fails."""
        ...

CreditProcessorProtocol

Bases: Protocol

CreditProcessorProtocol is a protocol that provides a method to process credit drops.

Source code in aiperf/workers/credit_processor_mixin.py
30
31
32
33
34
35
36
37
38
@runtime_checkable
class CreditProcessorProtocol(Protocol):
    """CreditProcessorProtocol is a protocol that provides a method to process credit drops.

    A narrower view than CreditProcessorMixinRequirements: only the processing entry point.
    """

    async def _process_credit_drop_internal(
        self, message: CreditDropMessage
    ) -> CreditReturnMessage:
        """Process a credit drop message. Return the CreditReturnMessage."""
        ...

aiperf.workers.worker

Worker

Bases: PullClientMixin, BaseComponentService, ProcessHealthMixin, CreditProcessorMixin

Worker is primarily responsible for making API calls to the inference server. It also manages the conversation between turns and returns the results to the Inference Results Parsers.

Source code in aiperf/workers/worker.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
@ServiceFactory.register(ServiceType.WORKER)
class Worker(
    PullClientMixin, BaseComponentService, ProcessHealthMixin, CreditProcessorMixin
):
    """Worker is primarily responsible for making API calls to the inference server.
    It also manages the conversation between turns and returns the results to the Inference Results Parsers.

    Credits are pulled from the CREDIT_DROP address, processed via
    CreditProcessorMixin, and returned on the CREDIT_RETURN address.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str | None = None,
        **kwargs,
    ):
        """Create the comms clients, request converter, and inference client."""
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
            pull_client_address=CommAddress.CREDIT_DROP,
            pull_client_bind=False,
            **kwargs,
        )

        self.debug(lambda: f"Worker process __init__ (pid: {self.process.pid})")

        self.health_check_interval = self.service_config.workers.health_check_interval

        # Per-phase task counters; populated lazily by CreditProcessorMixin.
        self.task_stats: dict[CreditPhase, WorkerPhaseTaskStats] = {}

        # Push client used to return processed credits to the timing manager.
        self.credit_return_push_client: PushClientProtocol = (
            self.comms.create_push_client(
                CommAddress.CREDIT_RETURN,
            )
        )
        # Push client for raw inference results (consumed downstream by the parsers).
        self.inference_results_push_client: PushClientProtocol = (
            self.comms.create_push_client(
                CommAddress.RAW_INFERENCE_PROXY_FRONTEND,
            )
        )
        # Request/reply client for fetching conversations from the dataset manager.
        self.conversation_request_client: RequestClientProtocol = (
            self.comms.create_request_client(
                CommAddress.DATASET_MANAGER_PROXY_FRONTEND,
            )
        )

        self.model_endpoint = ModelEndpointInfo.from_user_config(self.user_config)

        self.debug(
            lambda: f"Creating inference client for {self.model_endpoint.endpoint.type}, "
            f"class: {InferenceClientFactory.get_class_from_type(self.model_endpoint.endpoint.type).__name__}",
        )
        # Both the converter and the client are selected by endpoint type.
        self.request_converter = RequestConverterFactory.create_instance(
            self.model_endpoint.endpoint.type,
        )
        self.inference_client = InferenceClientFactory.create_instance(
            self.model_endpoint.endpoint.type,
            model_endpoint=self.model_endpoint,
        )

    @on_pull_message(MessageType.CREDIT_DROP)
    async def _credit_drop_callback(self, message: CreditDropMessage) -> None:
        """Handle an incoming credit drop message from the timing manager. Every credit must be returned after processing."""

        # Create a default credit return message in case of an exception
        credit_return_message = CreditReturnMessage(
            service_id=self.service_id,
            phase=message.phase,
        )

        try:
            # NOTE: This must be awaited to ensure that the max concurrency is respected
            credit_return_message = await self._process_credit_drop_internal(message)
        except Exception as e:
            self.exception(f"Error processing credit drop: {e}")
        finally:
            # It is fine to execute the push asynchronously here because the worker is technically
            # ready to process the next credit drop.
            self.execute_async(
                self.credit_return_push_client.push(credit_return_message)
            )

    @on_stop
    async def _shutdown_worker(self) -> None:
        """Close the inference client (if any) when the worker service stops."""
        self.debug("Shutting down worker")
        if self.inference_client:
            await self.inference_client.close()

    @background_task(
        immediate=False,
        interval=lambda self: self.health_check_interval,
    )
    async def _health_check_task(self) -> None:
        """Task to report the health of the worker to the worker manager."""
        await self.publish(self.create_health_message())

    def create_health_message(self) -> WorkerHealthMessage:
        """Build a WorkerHealthMessage snapshot of process health and task stats."""
        return WorkerHealthMessage(
            service_id=self.service_id,
            process=self.get_process_health(),
            task_stats=self.task_stats,
        )

    @on_command(CommandType.PROFILE_CANCEL)
    async def _handle_profile_cancel_command(
        self, message: ProfileCancelCommand
    ) -> None:
        """Acknowledge the cancel command, then stop this worker service."""
        self.debug(lambda: f"Received profile cancel command: {message}")
        await self.publish(
            CommandAcknowledgedResponse.from_command_message(message, self.service_id)
        )
        await self.stop()

aiperf.workers.worker_manager

WorkerManager

Bases: BaseComponentService

The WorkerManager service's primary responsibility is to manage the worker processes. It will spawn the workers, monitor their health, and stop them when the service is stopped. In the future it will also be responsible for the auto-scaling of the workers.

Source code in aiperf/workers/worker_manager.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
@ServiceFactory.register(ServiceType.WORKER_MANAGER)
class WorkerManager(BaseComponentService):
    """
    The WorkerManager service's primary responsibility is to manage the worker processes.
    It will spawn the workers, monitor their health, and stop them when the service is stopped.
    In the future it will also be responsible for the auto-scaling of the workers.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str | None = None,
        **kwargs,
    ):
        """Compute the worker-count bounds from the service and user configuration."""
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
            **kwargs,
        )

        self.trace("WorkerManager.__init__")
        # Bookkeeping: spawned worker processes and their latest health reports.
        self.workers: dict[str, WorkerProcessInfo] = {}
        self.worker_health: dict[str, WorkerHealthMessage] = {}

        self.cpu_count = multiprocessing.cpu_count()
        self.debug(lambda: f"Detected {self.cpu_count} CPU cores/threads")

        # NOTE(review): assumes loadgen.concurrency is an int (the `> 1` comparison
        # below would raise on None) -- confirm the config guarantees this.
        self.max_concurrency = self.user_config.loadgen.concurrency
        self.max_workers = self.service_config.workers.max
        if self.max_workers is None:
            # Default to the number of CPU cores - 1
            self.max_workers = self.cpu_count - 1

        # Cap the worker count to the max concurrency + 1, but only if the user is in concurrency mode.
        # NOTE(review): concurrency == 1 is intentionally not capped here -- confirm
        # the `> 1` cutoff is the intended definition of "concurrency mode".
        if self.max_concurrency > 1:
            self.max_workers = min(
                self.max_concurrency + 1,
                self.max_workers,
            )

        # Ensure we have at least the min workers
        self.max_workers = max(
            self.max_workers,
            self.service_config.workers.min or 0,
        )
        self.initial_workers = self.max_workers

    @on_start
    async def _start(self) -> None:
        """Start worker manager-specific components."""
        self.debug("WorkerManager starting")

        await self.send_command_and_wait_for_response(
            SpawnWorkersCommand(
                service_id=self.service_id,
                num_workers=self.initial_workers,
                # Target the system controller directly to avoid broadcasting to all services.
                target_service_type=ServiceType.SYSTEM_CONTROLLER,
            )
        )
        self.debug("WorkerManager started")

    @on_stop
    async def _stop(self) -> None:
        """Ask the system controller to shut down all worker processes."""
        self.debug("WorkerManager stopping")

        await self.publish(
            ShutdownWorkersCommand(
                service_id=self.service_id,
                all_workers=True,
                # Target the system controller directly to avoid broadcasting to all services.
                target_service_type=ServiceType.SYSTEM_CONTROLLER,
            )
        )

    @on_message(MessageType.WORKER_HEALTH)
    async def _on_worker_health(self, message: WorkerHealthMessage) -> None:
        """Record the most recent health report for each worker, keyed by service id."""
        self.debug(lambda: f"Received worker health message: {message}")
        self.worker_health[message.service_id] = message

WorkerProcessInfo

Bases: AIPerfBaseModel

Information about a worker process.

Source code in aiperf/workers/worker_manager.py
28
29
30
31
32
33
34
class WorkerProcessInfo(AIPerfBaseModel):
    """Information about a worker process."""

    # `process` holds a non-pydantic object (process handle / task), so arbitrary
    # types must be allowed for validation to accept it.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    worker_id: str = Field(..., description="ID of the worker process")
    process: Any = Field(None, description="Process object or task")